diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..b9a41bb --- /dev/null +++ b/.editorconfig @@ -0,0 +1,19 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +insert_final_newline = true +indent_style = space +indent_size = 2 +trim_trailing_whitespace = true + +[*.go] +indent_style = tab +indent_size = 4 + +[Makefile] +indent_style = tab + +[*.md] +trim_trailing_whitespace = false diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..42cd9a3 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,47 @@ +name: CI + +on: + push: + branches: + - main + pull_request: + +concurrency: + group: ci-${{ github.workflow }}-${{ github.ref }} + cancel-in-progress: true + +jobs: + build-and-test: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Nix + uses: cachix/install-nix-action@v27 + with: + extra_nix_config: | + experimental-features = nix-command flakes + + - name: Setup Go + uses: actions/setup-go@v5 + with: + go-version-file: go.work + + - name: Install just + uses: taiki-e/install-action@just + + - name: Nix lint + run: just lint-nix + + - name: Pre-commit hooks + run: just precommit-run + + - name: Guardrails + run: just ci-check + + - name: Nix build package + run: nix build 'path:.#sotto' + + - name: Nix run --help smoke + run: nix run 'path:.#sotto' -- --help diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000..035e291 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,37 @@ +name: Release + +on: + push: + tags: + - "v*" + +permissions: + contents: write + +jobs: + release: + runs-on: ubuntu-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Install Nix + uses: cachix/install-nix-action@v27 + with: + extra_nix_config: | + experimental-features = nix-command flakes + + - name: Build + run: nix build 'path:.#sotto' + + - name: Package binary + run: | + mkdir -p dist + cp result/bin/sotto dist/sotto + chmod +x dist/sotto + + - name: Create release + uses: softprops/action-gh-release@v2 + with: + files: dist/sotto + generate_release_notes: true diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..9e914d6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,27 @@ +# Build artifacts +/bin/ +/dist/ +/result +/result-* +/apps/sotto/sotto +/apps/sotto/vendor/ + +# Go +**/*.test +**/*.out +coverage.out + +# Editor/system +.DS_Store +.direnv/ +.env +.env.* + +# Runtime artifacts +*.log +*.sock + +# Local planning/session notes +.pi/ +PLAN.md +SESSION.md diff --git a/.golangci.yml b/.golangci.yml new file mode 100644 index 0000000..939b032 --- /dev/null +++ b/.golangci.yml @@ -0,0 +1,26 @@ +run: + timeout: 5m + tests: true + +linters: + disable-all: true + enable: + - errcheck + - gosimple + - govet + - ineffassign + - staticcheck + - typecheck + - unused + - revive + - misspell + +linters-settings: + revive: + rules: + - name: package-comments + disabled: true + +issues: + max-issues-per-linter: 0 + max-same-issues: 0 diff --git a/.just/ci.just b/.just/ci.just new file mode 100644 index 0000000..708c15a --- /dev/null +++ b/.just/ci.just @@ -0,0 +1,11 @@ +ci: + just fmt + just lint + just test + +ci-check: + just fmt-check + just lint + just test + just generate + git diff --exit-code -- apps/sotto/proto/gen/go diff --git a/.just/codegen.just b/.just/codegen.just new file mode 100644 index 0000000..bcb4697 --- /dev/null +++ b/.just/codegen.just @@ -0,0 +1,2 @@ +generate: tools + PATH="{{bin_dir}}:$PATH" "{{bin_dir}}/buf" generate apps/sotto/proto/third_party --template buf.gen.yaml diff --git a/.just/common.just b/.just/common.just new file mode 100644 index 0000000..9cf6e80 --- /dev/null +++ b/.just/common.just @@ -0,0 +1,5 @@ +bin_dir := justfile_directory() + "/bin" +tooling_flake := "path:" + justfile_directory() + +default: + @just --list diff --git a/.just/go.just b/.just/go.just new file mode 100644 index 0000000..6c36f03 --- /dev/null +++ b/.just/go.just @@ -0,0 +1,24 @@ +# Install pinned developer tools into ./bin. +tools: + mkdir -p "{{bin_dir}}" + GOBIN="{{bin_dir}}" go install github.com/bufbuild/buf/cmd/buf@v1.57.2 + GOBIN="{{bin_dir}}" go install google.golang.org/protobuf/cmd/protoc-gen-go@v1.36.6 + GOBIN="{{bin_dir}}" go install google.golang.org/grpc/cmd/protoc-gen-go-grpc@v1.5.1 + GOBIN="{{bin_dir}}" go install github.com/golangci/golangci-lint/cmd/golangci-lint@v1.64.8 + +# Format Go sources (excluding generated vendor paths). +fmt: + bash -euo pipefail -c 'mapfile -t files < <(find apps/sotto -type f -name "*.go" -not -path "*/vendor/*"); if [[ "${#files[@]}" -eq 0 ]]; then exit 0; fi; gofmt -w "${files[@]}"' + +# Check Go formatting only (no changes). +fmt-check: + bash -euo pipefail -c 'mapfile -t files < <(find apps/sotto -type f -name "*.go" -not -path "*/vendor/*"); if [[ "${#files[@]}" -eq 0 ]]; then exit 0; fi; test -z "$(gofmt -l "${files[@]}")"' + +lint: tools + "{{bin_dir}}/golangci-lint" run ./apps/sotto/... + +test: + go test ./apps/sotto/... + +test-integration: + go test -tags=integration ./apps/sotto/internal/audio -run Integration diff --git a/.just/hooks.just b/.just/hooks.just new file mode 100644 index 0000000..b44e8f6 --- /dev/null +++ b/.just/hooks.just @@ -0,0 +1,8 @@ +precommit-install: + nix develop '{{ tooling_flake }}' -c prek install --hook-type pre-commit --hook-type pre-push + +precommit-run: + nix develop '{{ tooling_flake }}' -c prek run --all-files --hook-stage pre-commit + +prepush-run: + nix develop '{{ tooling_flake }}' -c prek run --all-files --hook-stage pre-push diff --git a/.just/nix.just b/.just/nix.just new file mode 100644 index 0000000..6d1d6d6 --- /dev/null +++ b/.just/nix.just @@ -0,0 +1,13 @@ +nix-build-check: + nix build 'path:.#sotto' + +nix-run-help-check: + nix run 'path:.#sotto' -- --help + +fmt-nix: + nix develop '{{ tooling_flake }}' -c nixfmt flake.nix + +lint-nix: + nix develop '{{ tooling_flake }}' -c deadnix flake.nix + nix develop '{{ tooling_flake }}' -c statix check flake.nix + nix develop '{{ tooling_flake }}' -c nixfmt --check flake.nix diff --git a/.just/smoke.just b/.just/smoke.just new file mode 100644 index 0000000..dd2fc90 --- /dev/null +++ b/.just/smoke.just @@ -0,0 +1,11 @@ +smoke-riva-doctor: + sotto doctor + +smoke-riva-manual: + @echo "Run this in an active Hyprland session with local Riva up:" + @echo " 1) sotto doctor" + @echo " 2) sotto toggle # start recording" + @echo " 3) speak a short phrase" + @echo " 4) sotto toggle # stop+commit" + @echo " 5) verify clipboard/paste + cues" + @echo " 6) sotto cancel # verify cancel path" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..9b4dab2 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,57 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + exclude: ^apps/sotto/(vendor/|proto/third_party/) + stages: [pre-commit] + - id: end-of-file-fixer + exclude: ^apps/sotto/(vendor/|proto/third_party/) + stages: [pre-commit] + - id: check-added-large-files + exclude: ^apps/sotto/(vendor/|proto/third_party/) + stages: [pre-commit] + - id: check-merge-conflict + stages: [pre-commit] + + - repo: local + hooks: + - id: go-fmt-check + name: go fmt check + entry: just fmt-check + language: system + pass_filenames: false + always_run: true + stages: [pre-commit] + + - id: lint-nix + name: nix lint + entry: just lint-nix + language: system + pass_filenames: false + always_run: true + stages: [pre-push] + + - id: ci-check + name: ci-check + entry: just ci-check + language: system + pass_filenames: false + always_run: true + stages: [pre-push] + + - id: nix-build-check + name: nix build package + entry: just nix-build-check + language: system + pass_filenames: false + always_run: true + stages: [pre-push] + + - id: nix-run-help-check + name: nix run --help smoke + entry: just nix-run-help-check + language: system + pass_filenames: false + always_run: true + stages: [pre-push] diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 0000000..a6d0caf --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,95 @@ +# sotto Agent Guide + +## Scope + +Applies to the entire `sotto/` repository. + +## Mission + +Deliver a production-grade, local-first ASR CLI: + +- single Go binary +- no daemon/background service +- deterministic toggle/stop/cancel behavior +- strong cleanup and failure safety +- reproducible tooling + packaging + +## Fast Start (read in order) + +1. `README.md` (user-facing behavior) +2. `PLAN.md` (active checklist) +3. `SESSION.md` (what was actually executed) +4. `just --list` (task entrypoints) +5. Only then open the package(s) you need to change + +## Component Map + +| Task area | Primary paths | +| --- | --- | +| CLI + dispatch | `apps/sotto/internal/cli/`, `apps/sotto/internal/app/` | +| IPC + single-instance | `apps/sotto/internal/ipc/` | +| Session/FSM | `apps/sotto/internal/session/`, `apps/sotto/internal/fsm/` | +| Audio capture | `apps/sotto/internal/audio/` | +| Riva streaming | `apps/sotto/internal/riva/`, `apps/sotto/internal/pipeline/` | +| Transcript assembly | `apps/sotto/internal/transcript/` | +| Clipboard/paste output | `apps/sotto/internal/output/`, `apps/sotto/internal/hypr/` | +| Indicator + cues | `apps/sotto/internal/indicator/` | +| Config system | `apps/sotto/internal/config/` | +| Diagnostics/logging | `apps/sotto/internal/doctor/`, `apps/sotto/internal/logging/` | +| Tooling/packaging | `justfile`, `.just/`, `flake.nix`, `.github/workflows/` | +| Proto/codegen | `apps/sotto/proto/third_party/`, `apps/sotto/proto/gen/`, `buf.gen.yaml` | + +## Non-Negotiable Workflow Rules + +1. Read target files before editing. +2. Keep scope tight to the requested behavior. +3. Update `PLAN.md` checklist items only when executed + verified. +4. Log key decisions and commands in `SESSION.md`. +5. Add or update regression tests for behavior changes when feasible. +6. Do not claim runtime verification unless it was actually run. + +## Go Engineering Standards + +Write canonical, idiomatic Go: + +- `gofmt` clean, straightforward naming, small focused functions +- explicit constructors and dependency wiring (no hidden globals) +- `context.Context` first for cancelable/timeout-aware operations +- wrap errors with actionable context; use `errors.Is` for branching +- keep interfaces near consumers; avoid broad shared interfaces +- separate state/policy logic from I/O adapters +- avoid clever abstractions; prefer explicit control flow + +## Testing Policy + +- Use `testing` + `testify` (`require`/`assert`) as needed. +- Prefer real boundaries/resources (temp files, unix sockets, `httptest`, PATH fixtures). +- Do **not** introduce mocking frameworks. +- Riva runtime inference remains manual smoke (non-CI). + +## Config Change Contract (Mandatory) + +Any config-key change must update all relevant locations: + +1. `internal/config/types.go` +2. `internal/config/defaults.go` +3. `internal/config/parser.go` +4. `internal/config/validate.go` (if constraints change) +5. parser/validation tests +6. `docs/configuration.md` and any README examples +7. consuming defaults in external config repos when in scope + +## Required Checks Before Hand-off + +Run and report: + +1. `just ci-check` +2. `nix build 'path:.#sotto'` + +If skipped, state exactly what was skipped, why, and how to run it. + +## Safety + +- Never commit secrets (e.g., `NGC_API_KEY`). +- Avoid destructive shell operations unless explicitly requested. +- Do not edit outside `sotto/` unless explicitly asked. diff --git a/README.md b/README.md index e69de29..58a0f7e 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,114 @@ +# sotto + +[![CI](https://github.com/rbright/sotto/actions/workflows/ci.yml/badge.svg)](https://github.com/rbright/sotto/actions/workflows/ci.yml) + +Local-first speech-to-text CLI. + +`sotto` captures microphone audio, streams to a local ASR backend (Riva by default), assembles transcript text, and commits output to the clipboard with optional paste dispatch. + +## Why this exists + +- single-process CLI (no daemon) +- local-first by default (localhost Riva endpoints) +- explicit state machine (`toggle`, `stop`, `cancel`) +- deterministic config + observable runtime logs + +## Feature summary + +- single-instance command coordination via unix socket +- audio capture via PipeWire/Pulse +- streaming ASR via NVIDIA Riva gRPC +- transcript normalization + optional trailing space +- output adapters: + - clipboard command (`clipboard_cmd`) + - optional paste command override (`paste_cmd`) + - default Hyprland paste path (`hyprctl sendshortcut`) when `paste_cmd` is unset +- indicator backends: + - `hypr` notifications + - `desktop` (freedesktop notifications, e.g. mako) +- optional WAV cue files for start/stop/complete/cancel +- built-in environment diagnostics via `sotto doctor` + +## Platform scope (current) + +`sotto` is currently optimized for **Wayland + Hyprland** workflows. + +- default paste behavior uses `hyprctl` +- `doctor` currently checks for a Hyprland session + +You can still reduce Hyprland coupling by using: + +- `indicator.backend = desktop` +- `paste_cmd = "..."` (explicit command override) + +## Install + +### Nix (recommended) + +```bash +nix build 'path:.#sotto' +nix run 'path:.#sotto' -- --help +``` + +### From source + +```bash +just tools +go test ./apps/sotto/... +go build ./apps/sotto/cmd/sotto +``` + +## Quickstart + +```bash +sotto doctor +sotto toggle # start +sotto toggle # stop + commit +``` + +Core commands: + +```bash +sotto toggle +sotto stop +sotto cancel +sotto status +sotto devices +sotto doctor +sotto version +``` + +## Configuration + +Config resolution order: + +1. `--config ` +2. `$XDG_CONFIG_HOME/sotto/config.jsonc` +3. `~/.config/sotto/config.jsonc` + +Compatibility note: + +- if the default `.jsonc` file is missing, sotto will fall back to legacy `config.conf` automatically. + +See full key reference and examples in: + +- [`docs/configuration.md`](./docs/configuration.md) + +## Verification + +Required local gate before hand-off: + +```bash +just ci-check +nix build 'path:.#sotto' +``` + +Manual/local runtime checklist: + +- [`docs/verification.md`](./docs/verification.md) + +## Architecture and design docs + +- [`docs/architecture.md`](./docs/architecture.md) +- [`docs/modularity.md`](./docs/modularity.md) +- [`AGENTS.md`](./AGENTS.md) diff --git a/apps/sotto/cmd/sotto/main.go b/apps/sotto/cmd/sotto/main.go new file mode 100644 index 0000000..e62abdf --- /dev/null +++ b/apps/sotto/cmd/sotto/main.go @@ -0,0 +1,20 @@ +// Package main provides the sotto CLI process entrypoint. +package main + +import ( + "context" + "os" + "os/signal" + "syscall" + + "github.com/rbright/sotto/internal/app" +) + +// main wires process signal handling to the application runner. +func main() { + ctx, stop := signal.NotifyContext(context.Background(), os.Interrupt, syscall.SIGTERM) + defer stop() + + exitCode := app.Execute(ctx, os.Args[1:], os.Stdout, os.Stderr) + os.Exit(exitCode) +} diff --git a/apps/sotto/cmd/sotto/main_test.go b/apps/sotto/cmd/sotto/main_test.go new file mode 100644 index 0000000..4746df9 --- /dev/null +++ b/apps/sotto/cmd/sotto/main_test.go @@ -0,0 +1,58 @@ +package main + +import ( + "os" + "os/exec" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMainHelp(t *testing.T) { + output, err := runMainSubprocess(t, "--help") + require.NoError(t, err, string(output)) + require.Contains(t, string(output), "Usage:") +} + +func TestMainInvalidCommandExitsNonZero(t *testing.T) { + output, err := runMainSubprocess(t, "not-a-command") + require.Error(t, err) + + exitErr, ok := err.(*exec.ExitError) + require.True(t, ok) + require.Equal(t, 2, exitErr.ExitCode()) + require.Contains(t, string(output), "unknown command") +} + +func TestMainHelperProcess(t *testing.T) { + if os.Getenv("GO_WANT_HELPER_PROCESS") != "1" { + return + } + + args := os.Args + dashIndex := -1 + for i, arg := range args { + if arg == "--" { + dashIndex = i + break + } + } + + os.Args = []string{"sotto"} + if dashIndex >= 0 && dashIndex+1 < len(args) { + os.Args = append(os.Args, args[dashIndex+1:]...) + } + + main() +} + +func runMainSubprocess(t *testing.T, args ...string) ([]byte, error) { + t.Helper() + + cmdArgs := []string{"-test.run=TestMainHelperProcess", "--"} + cmdArgs = append(cmdArgs, args...) + + cmd := exec.Command(os.Args[0], cmdArgs...) + cmd.Env = append(os.Environ(), "GO_WANT_HELPER_PROCESS=1") + return cmd.CombinedOutput() +} diff --git a/apps/sotto/go.mod b/apps/sotto/go.mod new file mode 100644 index 0000000..41f7850 --- /dev/null +++ b/apps/sotto/go.mod @@ -0,0 +1,20 @@ +module github.com/rbright/sotto + +go 1.25.5 + +require ( + github.com/jfreymuth/pulse v0.1.1 + github.com/stretchr/testify v1.10.0 + google.golang.org/grpc v1.79.1 + google.golang.org/protobuf v1.36.11 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/net v0.48.0 // indirect + golang.org/x/sys v0.39.0 // indirect + golang.org/x/text v0.32.0 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/apps/sotto/go.sum b/apps/sotto/go.sum new file mode 100644 index 0000000..6f59a99 --- /dev/null +++ b/apps/sotto/go.sum @@ -0,0 +1,50 @@ +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/jfreymuth/pulse v0.1.1 h1:9WLNBNCijmtZ14ZJpatgJPu/NjwAl3TIKItSFnTh+9A= +github.com/jfreymuth/pulse v0.1.1/go.mod h1:cpYspI6YljhkUf1WLXLLDmeaaPFc3CnGLjDZf9dZ4no= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/sys v0.39.0 h1:CvCKL8MeisomCi6qNZ+wbb0DN9E5AATixKsvNtMoMFk= +golang.org/x/sys v0.39.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/grpc v1.79.1 h1:zGhSi45ODB9/p3VAawt9a+O/MULLl9dpizzNNpq7flY= +google.golang.org/grpc v1.79.1/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/apps/sotto/internal/app/app.go b/apps/sotto/internal/app/app.go new file mode 100644 index 0000000..bba924f --- /dev/null +++ b/apps/sotto/internal/app/app.go @@ -0,0 +1,349 @@ +// Package app wires CLI commands to runtime components and top-level execution flow. +package app + +import ( + "context" + "errors" + "fmt" + "io" + "log/slog" + "os" + "strings" + "syscall" + "time" + + "github.com/rbright/sotto/internal/audio" + "github.com/rbright/sotto/internal/cli" + "github.com/rbright/sotto/internal/config" + "github.com/rbright/sotto/internal/doctor" + "github.com/rbright/sotto/internal/indicator" + "github.com/rbright/sotto/internal/ipc" + "github.com/rbright/sotto/internal/logging" + "github.com/rbright/sotto/internal/output" + "github.com/rbright/sotto/internal/pipeline" + "github.com/rbright/sotto/internal/session" + "github.com/rbright/sotto/internal/version" +) + +// Runner holds process-level dependencies used by command handlers. +type Runner struct { + Stdout io.Writer + Stderr io.Writer + Logger *slog.Logger +} + +// Execute is the package entrypoint used by cmd/sotto/main.go. +func Execute(ctx context.Context, args []string, stdout, stderr io.Writer) int { + r := Runner{Stdout: stdout, Stderr: stderr} + return r.Execute(ctx, args) +} + +// Execute parses CLI arguments, loads config/logging, and dispatches a command. +func (r Runner) Execute(ctx context.Context, args []string) int { + parsed, err := cli.Parse(args) + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n\n", err) + fmt.Fprint(r.Stderr, cli.HelpText("sotto")) + return 2 + } + + if parsed.ShowHelp { + fmt.Fprint(r.Stdout, cli.HelpText("sotto")) + return 0 + } + + if parsed.Command == cli.CommandVersion { + fmt.Fprintln(r.Stdout, version.String()) + return 0 + } + + logRuntime, err := logging.New() + if err != nil { + fmt.Fprintf(r.Stderr, "error: setup logging: %v\n", err) + return 1 + } + defer func() { _ = logRuntime.Close() }() + + logger := r.Logger + if logger == nil { + logger = logRuntime.Logger + } + + cfgLoaded, err := config.Load(parsed.ConfigPath) + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", err) + logger.Error("load config failed", "error", err.Error()) + return 1 + } + for _, w := range cfgLoaded.Warnings { + msg := w.Message + if w.Line > 0 { + msg = fmt.Sprintf("line %d: %s", w.Line, w.Message) + } + fmt.Fprintf(r.Stderr, "warning: %s\n", msg) + logger.Warn("config warning", "line", w.Line, "message", w.Message) + } + + if speechPlan, _, err := config.BuildSpeechPhrases(cfgLoaded.Config); err == nil { + logger.Debug("speech context plan", "phrase_count", len(speechPlan), "phrases", speechPlan) + } + + logger.Info("command start", + "command", parsed.Command, + "config", cfgLoaded.Path, + "log", logRuntime.Path, + ) + + switch parsed.Command { + case cli.CommandDoctor: + report := doctor.Run(cfgLoaded) + fmt.Fprintln(r.Stdout, report.String()) + if report.OK() { + return 0 + } + return 1 + case cli.CommandDevices: + return r.commandDevices(ctx) + case cli.CommandStatus: + return r.commandStatus(ctx) + case cli.CommandStop: + return r.forwardOrFail(ctx, "stop") + case cli.CommandCancel: + return r.forwardOrFail(ctx, "cancel") + case cli.CommandToggle: + return r.commandToggle(ctx, cfgLoaded.Config, logger) + default: + fmt.Fprintf(r.Stderr, "error: unsupported command %q\n", parsed.Command) + return 2 + } +} + +// commandDevices prints discovered input devices and key availability metadata. +func (r Runner) commandDevices(ctx context.Context) int { + devices, err := audio.ListDevices(ctx) + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", err) + return 1 + } + if len(devices) == 0 { + fmt.Fprintln(r.Stdout, "no audio devices found") + return 1 + } + + for _, device := range devices { + defaultMark := " " + if device.Default { + defaultMark = "*" + } + availability := "yes" + if !device.Available { + availability = "no" + } + muted := "no" + if device.Muted { + muted = "yes" + } + fmt.Fprintf( + r.Stdout, + "%s id=%s | description=%q | state=%s | available=%s | muted=%s\n", + defaultMark, + device.ID, + device.Description, + device.State, + availability, + muted, + ) + } + + return 0 +} + +// commandStatus queries the active owner (if any) and prints session state. +func (r Runner) commandStatus(ctx context.Context) int { + socketPath, err := ipc.RuntimeSocketPath() + if err != nil { + fmt.Fprintln(r.Stdout, "idle") + return 0 + } + + resp, handled, err := tryForward(ctx, socketPath, "status") + if handled { + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", err) + return 1 + } + if resp.State == "" { + resp.State = "idle" + } + fmt.Fprintln(r.Stdout, resp.State) + return 0 + } + + fmt.Fprintln(r.Stdout, "idle") + return 0 +} + +// forwardOrFail forwards a command to the active owner and fails when no owner exists. +func (r Runner) forwardOrFail(ctx context.Context, command string) int { + socketPath, err := ipc.RuntimeSocketPath() + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", err) + return 1 + } + + resp, handled, err := tryForward(ctx, socketPath, command) + if !handled { + fmt.Fprintf(r.Stderr, "error: no active sotto session\n") + return 1 + } + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", err) + return 1 + } + if resp.Message != "" { + fmt.Fprintln(r.Stdout, resp.Message) + } + return 0 +} + +// commandToggle starts a new owner session or forwards toggle to an existing owner. +func (r Runner) commandToggle(ctx context.Context, cfg config.Config, logger *slog.Logger) int { + socketPath, err := ipc.RuntimeSocketPath() + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", err) + return 1 + } + + resp, handled, err := tryForward(ctx, socketPath, "toggle") + if handled { + if err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", err) + return 1 + } + if resp.Message != "" { + fmt.Fprintln(r.Stdout, resp.Message) + } + return 0 + } + + listener, err := ipc.Acquire(ctx, socketPath, 180*time.Millisecond, 8, nil) + if err != nil { + if errors.Is(err, ipc.ErrAlreadyRunning) { + resp, _, forwardErr := tryForward(ctx, socketPath, "toggle") + if forwardErr != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", forwardErr) + return 1 + } + if resp.Message != "" { + fmt.Fprintln(r.Stdout, resp.Message) + } + return 0 + } + fmt.Fprintf(r.Stderr, "error: %v\n", err) + return 1 + } + defer func() { + _ = listener.Close() + _ = os.Remove(socketPath) + }() + + transcriber := pipeline.NewTranscriber(cfg, logger) + committer := output.NewCommitter(cfg, logger) + indicatorCtl := indicator.NewHyprNotify(cfg.Indicator, logger) + controller := session.NewController(logger, transcriber, committer, indicatorCtl) + + serverCtx, serverCancel := context.WithCancel(ctx) + defer serverCancel() + + serverErrCh := make(chan error, 1) + go func() { + serverErrCh <- ipc.Serve(serverCtx, listener, controller) + }() + + result := controller.Run(ctx) + serverCancel() + if serverErr := <-serverErrCh; serverErr != nil { + fmt.Fprintf(r.Stderr, "error: ipc server failed: %v\n", serverErr) + return 1 + } + + logSessionResult(logger, result) + + if result.Cancelled { + fmt.Fprintln(r.Stdout, "cancelled") + return 0 + } + if result.Err != nil { + fmt.Fprintf(r.Stderr, "error: %v\n", result.Err) + return 1 + } + if strings.TrimSpace(result.Transcript) != "" { + fmt.Fprintln(r.Stdout, strings.TrimSpace(result.Transcript)) + } + + return 0 +} + +// logSessionResult writes normalized session metrics into the runtime logger. +func logSessionResult(logger *slog.Logger, result session.Result) { + if logger == nil { + return + } + fields := []any{ + "state", result.State, + "cancelled", result.Cancelled, + "started_at", result.StartedAt.Format(time.RFC3339Nano), + "finished_at", result.FinishedAt.Format(time.RFC3339Nano), + "duration_ms", result.FinishedAt.Sub(result.StartedAt).Milliseconds(), + "audio_device", result.AudioDevice, + "bytes_captured", result.BytesCaptured, + "transcript_length", len(result.Transcript), + "grpc_latency_ms", result.GRPCLatency.Milliseconds(), + "focused_monitor", result.FocusedMonitor, + } + + if result.Err != nil { + logger.Error("session failed", append(fields, "error", result.Err.Error())...) + return + } + logger.Info("session complete", fields...) +} + +// tryForward attempts to send a command to an existing owner and classifies outcome. +// +// handled=false means there was no active owner to handle the request. +func tryForward(ctx context.Context, socketPath string, command string) (ipc.Response, bool, error) { + resp, err := ipc.Send(ctx, socketPath, ipc.Request{Command: command}, 220*time.Millisecond) + if err == nil { + if resp.OK { + return resp, true, nil + } + return resp, true, errors.New(resp.Error) + } + + if isSocketMissing(err) { + return ipc.Response{}, false, nil + } + if isConnectionRefused(err) { + return ipc.Response{}, false, nil + } + + return ipc.Response{}, true, fmt.Errorf("forward command %q: %w", command, err) +} + +// isSocketMissing reports whether forwarding failed because the owner socket is absent. +func isSocketMissing(err error) bool { + if err == nil { + return false + } + return errors.Is(err, os.ErrNotExist) || + strings.Contains(err.Error(), "no such file or directory") +} + +// isConnectionRefused reports whether forwarding failed because no owner is listening. +func isConnectionRefused(err error) bool { + if err == nil { + return false + } + return errors.Is(err, syscall.ECONNREFUSED) +} diff --git a/apps/sotto/internal/app/app_test.go b/apps/sotto/internal/app/app_test.go new file mode 100644 index 0000000..6f82069 --- /dev/null +++ b/apps/sotto/internal/app/app_test.go @@ -0,0 +1,239 @@ +package app + +import ( + "bytes" + "context" + "net" + "os" + "path/filepath" + "testing" + + "github.com/rbright/sotto/internal/ipc" + "github.com/stretchr/testify/require" +) + +func TestExecuteHelp(t *testing.T) { + var stdout bytes.Buffer + var stderr bytes.Buffer + + exitCode := Execute(context.Background(), []string{"--help"}, &stdout, &stderr) + require.Equal(t, 0, exitCode) + require.Contains(t, stdout.String(), "Usage:") + require.Empty(t, stderr.String()) +} + +func TestExecuteVersion(t *testing.T) { + var stdout bytes.Buffer + var stderr bytes.Buffer + + exitCode := Execute(context.Background(), []string{"version"}, &stdout, &stderr) + require.Equal(t, 0, exitCode) + require.Contains(t, stdout.String(), "sotto") + require.Empty(t, stderr.String()) +} + +func TestExecuteUnknownCommand(t *testing.T) { + var stdout bytes.Buffer + var stderr bytes.Buffer + + exitCode := Execute(context.Background(), []string{"definitely-not-a-command"}, &stdout, &stderr) + require.Equal(t, 2, exitCode) + require.Contains(t, stderr.String(), "unknown command") + require.Contains(t, stderr.String(), "Usage:") +} + +func TestRunnerStatusIdleWhenSocketUnavailable(t *testing.T) { + paths := setupRunnerEnv(t) + + var stdout bytes.Buffer + var stderr bytes.Buffer + runner := Runner{Stdout: &stdout, Stderr: &stderr} + + exitCode := runner.Execute(context.Background(), []string{"--config", paths.configPath, "status"}) + require.Equal(t, 0, exitCode) + require.Equal(t, "idle\n", stdout.String()) + require.Empty(t, stderr.String()) +} + +func TestRunnerStopReturnsNoActiveSession(t *testing.T) { + paths := setupRunnerEnv(t) + + var stdout bytes.Buffer + var stderr bytes.Buffer + runner := Runner{Stdout: &stdout, Stderr: &stderr} + + exitCode := runner.Execute(context.Background(), []string{"--config", paths.configPath, "stop"}) + require.Equal(t, 1, exitCode) + require.Contains(t, stderr.String(), "no active sotto session") +} + +func TestRunnerForwardsCommandsToActiveSession(t *testing.T) { + paths := setupRunnerEnv(t) + commands := make(chan string, 8) + + shutdown := startIPCServerForRunnerTest(t, filepath.Join(paths.runtimeDir, "sotto.sock"), func(_ context.Context, req ipc.Request) ipc.Response { + commands <- req.Command + switch req.Command { + case "status": + return ipc.Response{OK: true, State: "recording"} + case "stop", "cancel", "toggle": + return ipc.Response{OK: true, Message: req.Command + " handled"} + default: + return ipc.Response{OK: false, Error: "unsupported"} + } + }) + defer shutdown() + + runner := Runner{Stdout: &bytes.Buffer{}, Stderr: &bytes.Buffer{}} + + for _, cmd := range []string{"status", "stop", "cancel", "toggle"} { + stdout := &bytes.Buffer{} + stderr := &bytes.Buffer{} + runner.Stdout = stdout + runner.Stderr = stderr + + exitCode := runner.Execute(context.Background(), []string{"--config", paths.configPath, cmd}) + require.Equal(t, 0, exitCode, cmd) + require.Empty(t, stderr.String(), cmd) + } + + got := []string{<-commands, <-commands, <-commands, <-commands} + require.ElementsMatch(t, []string{"status", "stop", "cancel", "toggle"}, got) +} + +func TestTryForwardSuccessAndFailureResponses(t *testing.T) { + runtimeDir := t.TempDir() + socketPath := filepath.Join(runtimeDir, "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + serverCtx, cancelServer := context.WithCancel(context.Background()) + serverDone := make(chan error, 1) + go func() { + serverDone <- ipc.Serve(serverCtx, listener, ipc.HandlerFunc(func(_ context.Context, req ipc.Request) ipc.Response { + switch req.Command { + case "status": + return ipc.Response{OK: true, State: "recording"} + default: + return ipc.Response{OK: false, Error: "unsupported"} + } + })) + }() + + resp, handled, err := tryForward(context.Background(), socketPath, "status") + require.True(t, handled) + require.NoError(t, err) + require.Equal(t, "recording", resp.State) + + _, handled, err = tryForward(context.Background(), socketPath, "cancel") + require.True(t, handled) + require.Error(t, err) + require.Contains(t, err.Error(), "unsupported") + + cancelServer() + require.NoError(t, <-serverDone) +} + +func TestTryForwardDoesNotRemoveSocketPathOnForwardFailure(t *testing.T) { + socketPath := filepath.Join(t.TempDir(), "sotto.sock") + require.NoError(t, os.WriteFile(socketPath, []byte("stale"), 0o600)) + + _, handled, err := tryForward(context.Background(), socketPath, "status") + require.False(t, handled) + require.NoError(t, err) + + _, statErr := os.Stat(socketPath) + require.NoError(t, statErr) +} + +func TestTryForwardTreatsReadFailuresAsHandledErrors(t *testing.T) { + socketPath := filepath.Join(t.TempDir(), "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + done := make(chan struct{}) + go func() { + defer close(done) + conn, acceptErr := listener.Accept() + if acceptErr == nil { + _ = conn.Close() + } + }() + + _, handled, err := tryForward(context.Background(), socketPath, "status") + require.True(t, handled) + require.Error(t, err) + require.Contains(t, err.Error(), "forward command \"status\":") + + <-done + _, statErr := os.Stat(socketPath) + require.NoError(t, statErr) + require.NoError(t, listener.Close()) +} + +func TestRunnerDoctorCommandDispatchesAndPrintsReport(t *testing.T) { + paths := setupRunnerEnv(t) + t.Setenv("XDG_SESSION_TYPE", "x11") + t.Setenv("HYPRLAND_INSTANCE_SIGNATURE", "") + + var stdout bytes.Buffer + var stderr bytes.Buffer + runner := Runner{Stdout: &stdout, Stderr: &stderr} + + exitCode := runner.Execute(context.Background(), []string{"--config", paths.configPath, "doctor"}) + require.Equal(t, 1, exitCode) + require.Contains(t, stdout.String(), "config: loaded") + require.Contains(t, stdout.String(), "XDG_SESSION_TYPE") +} + +func TestRunnerDevicesCommandDispatches(t *testing.T) { + paths := setupRunnerEnv(t) + t.Setenv("PULSE_SERVER", "unix:/tmp/definitely-missing-pulse-server") + + var stdout bytes.Buffer + var stderr bytes.Buffer + runner := Runner{Stdout: &stdout, Stderr: &stderr} + + exitCode := runner.Execute(context.Background(), []string{"--config", paths.configPath, "devices"}) + require.Equal(t, 1, exitCode) + require.Contains(t, stderr.String(), "error:") +} + +type runnerPaths struct { + configPath string + runtimeDir string +} + +func setupRunnerEnv(t *testing.T) runnerPaths { + t.Helper() + + xdgStateHome := t.TempDir() + runtimeDir := t.TempDir() + t.Setenv("XDG_STATE_HOME", xdgStateHome) + t.Setenv("XDG_RUNTIME_DIR", runtimeDir) + + configPath := filepath.Join(t.TempDir(), "config.conf") + require.NoError(t, os.WriteFile(configPath, []byte("\n"), 0o600)) + + return runnerPaths{configPath: configPath, runtimeDir: runtimeDir} +} + +func startIPCServerForRunnerTest(t *testing.T, socketPath string, handler func(context.Context, ipc.Request) ipc.Response) func() { + t.Helper() + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan error, 1) + go func() { + done <- ipc.Serve(ctx, listener, ipc.HandlerFunc(handler)) + }() + + return func() { + cancel() + require.NoError(t, <-done) + } +} diff --git a/apps/sotto/internal/audio/pulse.go b/apps/sotto/internal/audio/pulse.go new file mode 100644 index 0000000..1accdf1 --- /dev/null +++ b/apps/sotto/internal/audio/pulse.go @@ -0,0 +1,400 @@ +// Package audio handles device discovery, selection, and PCM capture streams. +package audio + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "sync/atomic" + + "github.com/jfreymuth/pulse" + pulseproto "github.com/jfreymuth/pulse/proto" +) + +const ( + chunkSizeBytes = 640 // 20ms @ 16kHz mono s16 +) + +// Device describes one Pulse input source surfaced to sotto. +type Device struct { + ID string + Description string + State string + Available bool + Muted bool + Default bool +} + +// Selection is the resolved capture source plus optional fallback warning context. +type Selection struct { + Device Device + Warning string + Fallback bool +} + +// ListDevices returns available Pulse input sources with default/availability metadata. +func ListDevices(_ context.Context) ([]Device, error) { + client, err := pulse.NewClient( + pulse.ClientApplicationName("sotto"), + pulse.ClientApplicationIconName("audio-input-microphone"), + ) + if err != nil { + return nil, fmt.Errorf("connect pulse server: %w", err) + } + defer client.Close() + + defaultSource, err := client.DefaultSource() + if err != nil { + return nil, fmt.Errorf("read default source: %w", err) + } + defaultID := defaultSource.ID() + + var sourceInfos pulseproto.GetSourceInfoListReply + if err := client.RawRequest(&pulseproto.GetSourceInfoList{}, &sourceInfos); err != nil { + return nil, fmt.Errorf("list sources: %w", err) + } + + devices := make([]Device, 0, len(sourceInfos)) + for _, source := range sourceInfos { + if source == nil { + continue + } + devices = append(devices, Device{ + ID: source.SourceName, + Description: source.Device, + State: sourceStateString(source.State), + Available: sourceAvailable(source), + Muted: source.Mute, + Default: source.SourceName == defaultID, + }) + } + return devices, nil +} + +// SelectDevice resolves audio.input/audio.fallback preferences against live devices. +func SelectDevice(ctx context.Context, input string, fallback string) (Selection, error) { + devices, err := ListDevices(ctx) + if err != nil { + return Selection{}, err + } + return selectDeviceFromList(devices, input, fallback) +} + +// selectDeviceFromList applies selection policy to a pre-fetched device list. +func selectDeviceFromList(devices []Device, input string, fallback string) (Selection, error) { + if len(devices) == 0 { + return Selection{}, errors.New("no audio input devices found") + } + + var ( + defaultDevice *Device + byInput *Device + byFallback *Device + ) + + input = strings.TrimSpace(strings.ToLower(input)) + fallback = strings.TrimSpace(strings.ToLower(fallback)) + + for i := range devices { + dev := &devices[i] + if dev.Default { + defaultDevice = dev + } + if byInput == nil && input != "" && input != "default" && deviceMatches(*dev, input) { + byInput = dev + } + if byFallback == nil && fallback != "" && fallback != "default" && deviceMatches(*dev, fallback) { + byFallback = dev + } + } + + chooseDefault := func() (*Device, error) { + if defaultDevice == nil { + return nil, errors.New("default audio source is unavailable") + } + return defaultDevice, nil + } + + selectPrimary := func() (*Device, error) { + if input == "" || input == "default" { + return chooseDefault() + } + if byInput != nil { + return byInput, nil + } + return nil, fmt.Errorf("audio.input %q did not match any device", input) + } + + primary, err := selectPrimary() + if err != nil { + return Selection{}, err + } + if primary.Available && !primary.Muted { + return Selection{Device: *primary}, nil + } + + primaryReason := "unavailable" + if primary.Muted { + primaryReason = "muted" + } + + fallbackDevice := primary + if fallback != "" && fallback != "default" { + if byFallback == nil { + return Selection{}, fmt.Errorf("primary input %q is %s and fallback %q not found", primary.ID, primaryReason, fallback) + } + fallbackDevice = byFallback + } else { + d, derr := chooseDefault() + if derr != nil { + return Selection{}, fmt.Errorf("primary input %q is %s and no usable fallback: %w", primary.ID, primaryReason, derr) + } + fallbackDevice = d + } + + if !fallbackDevice.Available { + return Selection{}, fmt.Errorf("audio fallback device %q is not available", fallbackDevice.ID) + } + if fallbackDevice.Muted { + return Selection{}, fmt.Errorf("audio fallback device %q is muted", fallbackDevice.ID) + } + + return Selection{ + Device: *fallbackDevice, + Warning: fmt.Sprintf("audio.input %q is %s; falling back to %q", primary.ID, primaryReason, fallbackDevice.ID), + Fallback: primary.ID != fallbackDevice.ID, + }, nil +} + +// deviceMatches reports whether a search term matches a device id or description. +func deviceMatches(device Device, term string) bool { + if term == "" { + return false + } + id := strings.ToLower(device.ID) + desc := strings.ToLower(device.Description) + return strings.Contains(id, term) || strings.Contains(desc, term) +} + +// Capture streams fixed-size PCM chunks from one selected Pulse source. +type Capture struct { + device Device + + client *pulse.Client + stream *pulse.RecordStream + + chunks chan []byte + stopCh chan struct{} + + mu sync.Mutex + pending []byte + rawPCM []byte + stopped bool + + inflight sync.WaitGroup + bytes atomic.Int64 +} + +// StartCapture creates and starts a 16kHz mono s16 record stream. +func StartCapture(ctx context.Context, selected Device) (*Capture, error) { + client, err := pulse.NewClient( + pulse.ClientApplicationName("sotto"), + pulse.ClientApplicationIconName("audio-input-microphone"), + ) + if err != nil { + return nil, fmt.Errorf("connect pulse server: %w", err) + } + + source, err := client.SourceByID(selected.ID) + if err != nil { + client.Close() + return nil, fmt.Errorf("resolve source %q: %w", selected.ID, err) + } + + capture := &Capture{ + device: selected, + client: client, + chunks: make(chan []byte, 128), + stopCh: make(chan struct{}), + } + + writer := pulse.NewWriter(writerFunc(capture.onPCM), pulseproto.FormatInt16LE) + stream, err := client.NewRecord( + writer, + pulse.RecordSource(source), + pulse.RecordMono, + pulse.RecordSampleRate(16000), + pulse.RecordBufferFragmentSize(chunkSizeBytes), + pulse.RecordMediaName("sotto dictation"), + ) + if err != nil { + capture.Close() + return nil, fmt.Errorf("create pulse record stream: %w", err) + } + + capture.stream = stream + stream.Start() + + go func() { + <-ctx.Done() + _ = capture.Stop() + }() + + return capture, nil +} + +// Device returns capture metadata for logging and diagnostics. +func (c *Capture) Device() Device { + return c.device +} + +// Chunks returns the PCM stream as fixed-size byte slices. +func (c *Capture) Chunks() <-chan []byte { + return c.chunks +} + +// BytesCaptured reports total bytes accepted from Pulse. +func (c *Capture) BytesCaptured() int64 { + return c.bytes.Load() +} + +// RawPCM returns a snapshot of all captured raw PCM bytes. +func (c *Capture) RawPCM() []byte { + c.mu.Lock() + defer c.mu.Unlock() + out := make([]byte, len(c.rawPCM)) + copy(out, c.rawPCM) + return out +} + +// Stop halts the stream, flushes residual PCM, and closes Chunks exactly once. +func (c *Capture) Stop() error { + c.mu.Lock() + if c.stopped { + c.mu.Unlock() + return nil + } + c.stopped = true + close(c.stopCh) + c.mu.Unlock() + + if c.stream != nil { + c.stream.Stop() + c.stream.Close() + } + if c.client != nil { + c.client.Close() + } + + c.inflight.Wait() + + c.mu.Lock() + pending := append([]byte(nil), c.pending...) + c.pending = nil + c.mu.Unlock() + + if len(pending) > 0 { + chunk := make([]byte, len(pending)) + copy(chunk, pending) + select { + case c.chunks <- chunk: + default: + } + } + + close(c.chunks) + return nil +} + +// Close is a convenience alias for Stop. +func (c *Capture) Close() { + _ = c.Stop() +} + +// onPCM receives raw Pulse frames and emits chunkSizeBytes slices to c.chunks. +func (c *Capture) onPCM(buffer []byte) (int, error) { + if len(buffer) == 0 { + return 0, nil + } + + select { + case <-c.stopCh: + return 0, io.EOF + default: + } + + c.mu.Lock() + if c.stopped { + c.mu.Unlock() + return 0, io.EOF + } + // Guard Add under the same mutex as c.stopped to avoid Add/Wait races. + c.inflight.Add(1) + + c.rawPCM = append(c.rawPCM, buffer...) + c.pending = append(c.pending, buffer...) + + chunks := make([][]byte, 0, len(c.pending)/chunkSizeBytes) + for len(c.pending) >= chunkSizeBytes { + chunk := make([]byte, chunkSizeBytes) + copy(chunk, c.pending[:chunkSizeBytes]) + c.pending = c.pending[chunkSizeBytes:] + chunks = append(chunks, chunk) + } + c.mu.Unlock() + defer c.inflight.Done() + + c.bytes.Add(int64(len(buffer))) + + for _, chunk := range chunks { + select { + case <-c.stopCh: + return 0, io.EOF + case c.chunks <- chunk: + } + } + + return len(buffer), nil +} + +// writerFunc adapts a function to io.Writer for pulse.NewWriter. +type writerFunc func([]byte) (int, error) + +func (f writerFunc) Write(b []byte) (int, error) { + return f(b) +} + +// sourceStateString maps Pulse source state constants to human-readable values. +func sourceStateString(state uint32) string { + switch state { + case 0: + return "running" + case 1: + return "idle" + case 2: + return "suspended" + default: + return fmt.Sprintf("unknown(%d)", state) + } +} + +// sourceAvailable maps Pulse source port availability to a simple boolean. +func sourceAvailable(source *pulseproto.GetSourceInfoReply) bool { + if source == nil { + return false + } + if len(source.Ports) == 0 { + return true + } + for _, port := range source.Ports { + if port.Name != source.ActivePortName { + continue + } + // PulseAudio values: unknown=0, no=1, yes=2. + return port.Available == 0 || port.Available == 2 + } + return true +} diff --git a/apps/sotto/internal/audio/pulse_integration_test.go b/apps/sotto/internal/audio/pulse_integration_test.go new file mode 100644 index 0000000..74a72be --- /dev/null +++ b/apps/sotto/internal/audio/pulse_integration_test.go @@ -0,0 +1,20 @@ +//go:build integration + +package audio + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestListDevicesIntegration(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + devices, err := ListDevices(ctx) + require.NoError(t, err) + require.NotEmpty(t, devices) +} diff --git a/apps/sotto/internal/audio/pulse_test.go b/apps/sotto/internal/audio/pulse_test.go new file mode 100644 index 0000000..8ffa322 --- /dev/null +++ b/apps/sotto/internal/audio/pulse_test.go @@ -0,0 +1,158 @@ +package audio + +import ( + "io" + "reflect" + "testing" + + pulseproto "github.com/jfreymuth/pulse/proto" + "github.com/stretchr/testify/require" +) + +func TestSelectDeviceFromListPrimaryDefault(t *testing.T) { + devices := []Device{ + {ID: "elgato", Description: "Elgato Wave 3 Mono", Available: true, Default: true}, + {ID: "sony", Description: "Sony WH-1000XM6", Available: true}, + } + + selection, err := selectDeviceFromList(devices, "default", "default") + require.NoError(t, err) + require.Equal(t, "elgato", selection.Device.ID) + require.Empty(t, selection.Warning) +} + +func TestSelectDeviceFromListMutedPrimaryUsesFallback(t *testing.T) { + devices := []Device{ + {ID: "elgato", Description: "Elgato Wave 3 Mono", Available: true, Muted: true, Default: true}, + {ID: "sony", Description: "Sony WH-1000XM6", Available: true}, + } + + selection, err := selectDeviceFromList(devices, "elgato", "sony") + require.NoError(t, err) + require.Equal(t, "sony", selection.Device.ID) + require.Contains(t, selection.Warning, "muted") + require.True(t, selection.Fallback) +} + +func TestSelectDeviceFromListFailsWhenSelectedAndFallbackMuted(t *testing.T) { + devices := []Device{ + {ID: "elgato", Description: "Elgato Wave 3 Mono", Available: true, Muted: true, Default: true}, + } + + _, err := selectDeviceFromList(devices, "default", "default") + require.Error(t, err) + require.Contains(t, err.Error(), "muted") +} + +func TestSelectDeviceFromListUnknownInput(t *testing.T) { + devices := []Device{{ID: "elgato", Description: "Elgato Wave 3 Mono", Available: true, Default: true}} + + _, err := selectDeviceFromList(devices, "missing", "default") + require.Error(t, err) + require.Contains(t, err.Error(), "did not match") +} + +func TestDeviceMatchesByIDAndDescription(t *testing.T) { + dev := Device{ID: "alsa_input.usb-elgato", Description: "Elgato Wave 3 Mono"} + require.True(t, deviceMatches(dev, "elgato")) + require.True(t, deviceMatches(dev, "wave 3")) + require.False(t, deviceMatches(dev, "missing")) +} + +func TestSourceStateString(t *testing.T) { + require.Equal(t, "running", sourceStateString(0)) + require.Equal(t, "idle", sourceStateString(1)) + require.Equal(t, "suspended", sourceStateString(2)) + require.Equal(t, "unknown(99)", sourceStateString(99)) +} + +func TestSourceAvailable(t *testing.T) { + require.False(t, sourceAvailable(nil)) + require.True(t, sourceAvailable(&pulseproto.GetSourceInfoReply{})) // no ports => available + + available := &pulseproto.GetSourceInfoReply{ActivePortName: "mic"} + setSourcePorts(t, available, []sourcePort{{name: "mic", available: 2}}) + require.True(t, sourceAvailable(available)) + + notAvailable := &pulseproto.GetSourceInfoReply{ActivePortName: "mic"} + setSourcePorts(t, notAvailable, []sourcePort{{name: "mic", available: 1}}) + require.False(t, sourceAvailable(notAvailable)) +} + +func TestWriterFuncDelegatesWrite(t *testing.T) { + called := false + writer := writerFunc(func(b []byte) (int, error) { + called = true + require.Equal(t, []byte{1, 2, 3}, b) + return len(b), nil + }) + + n, err := writer.Write([]byte{1, 2, 3}) + require.NoError(t, err) + require.Equal(t, 3, n) + require.True(t, called) +} + +func TestCaptureOnPCMChunkingAndStopFlushesPending(t *testing.T) { + capture := &Capture{ + chunks: make(chan []byte, 8), + stopCh: make(chan struct{}), + } + + input := make([]byte, chunkSizeBytes+111) + for i := range input { + input[i] = byte(i % 255) + } + + n, err := capture.onPCM(input) + require.NoError(t, err) + require.Equal(t, len(input), n) + require.Equal(t, int64(len(input)), capture.BytesCaptured()) + require.Equal(t, len(input), len(capture.RawPCM())) + + firstChunk := <-capture.Chunks() + require.Len(t, firstChunk, chunkSizeBytes) + + require.NoError(t, capture.Stop()) + + remaining, ok := <-capture.Chunks() + require.True(t, ok) + require.Len(t, remaining, 111) + + _, ok = <-capture.Chunks() + require.False(t, ok) +} + +func TestCaptureOnPCMReturnsEOFWhenStopped(t *testing.T) { + capture := &Capture{ + chunks: make(chan []byte, 1), + stopCh: make(chan struct{}), + } + close(capture.stopCh) + + n, err := capture.onPCM([]byte{1, 2, 3}) + require.Equal(t, 0, n) + require.ErrorIs(t, err, io.EOF) + require.Equal(t, int64(0), capture.BytesCaptured()) +} + +type sourcePort struct { + name string + available uint32 +} + +func setSourcePorts(t *testing.T, reply *pulseproto.GetSourceInfoReply, ports []sourcePort) { + t.Helper() + + sliceType := reflect.TypeOf(reply.Ports) + sliceValue := reflect.MakeSlice(sliceType, len(ports), len(ports)) + + for i, port := range ports { + item := sliceValue.Index(i) + item.FieldByName("Name").SetString(port.name) + item.FieldByName("Available").SetUint(uint64(port.available)) + } + + replyValue := reflect.ValueOf(reply).Elem().FieldByName("Ports") + replyValue.Set(sliceValue) +} diff --git a/apps/sotto/internal/cli/cli.go b/apps/sotto/internal/cli/cli.go new file mode 100644 index 0000000..3a8eef7 --- /dev/null +++ b/apps/sotto/internal/cli/cli.go @@ -0,0 +1,103 @@ +// Package cli defines command parsing and user help text contracts. +package cli + +import ( + "errors" + "fmt" + "strings" +) + +// Command is the user-facing subcommand vocabulary for the CLI. +type Command string + +const ( + CommandToggle Command = "toggle" + CommandStop Command = "stop" + CommandCancel Command = "cancel" + CommandStatus Command = "status" + CommandDevices Command = "devices" + CommandDoctor Command = "doctor" + CommandVersion Command = "version" + CommandHelp Command = "help" +) + +var validCommands = map[Command]struct{}{ + CommandToggle: {}, + CommandStop: {}, + CommandCancel: {}, + CommandStatus: {}, + CommandDevices: {}, + CommandDoctor: {}, + CommandVersion: {}, + CommandHelp: {}, +} + +// Parsed contains normalized argument parsing output. +type Parsed struct { + Command Command + ConfigPath string + ShowHelp bool +} + +// Parse converts argv into a Parsed command contract with validation. +func Parse(args []string) (Parsed, error) { + parsed := Parsed{Command: CommandHelp, ShowHelp: true} + + for i := 0; i < len(args); i++ { + arg := args[i] + + switch arg { + case "-h", "--help": + parsed.ShowHelp = true + parsed.Command = CommandHelp + case "--version": + parsed.ShowHelp = false + parsed.Command = CommandVersion + case "--config": + i++ + if i >= len(args) { + return Parsed{}, errors.New("--config requires a path") + } + parsed.ConfigPath = args[i] + default: + if strings.HasPrefix(arg, "-") { + return Parsed{}, fmt.Errorf("unknown flag: %s", arg) + } + + cmd := Command(arg) + if _, ok := validCommands[cmd]; !ok { + return Parsed{}, fmt.Errorf("unknown command: %s", arg) + } + + parsed.Command = cmd + parsed.ShowHelp = cmd == CommandHelp + if i != len(args)-1 { + return Parsed{}, fmt.Errorf("unexpected arguments after command %q", arg) + } + } + } + + return parsed, nil +} + +// HelpText returns full usage text shown for --help and parse errors. +func HelpText(binaryName string) string { + return fmt.Sprintf(`Usage: + %[1]s [--config PATH] + +Commands: + toggle Start recording or stop+commit when already recording + stop Stop active recording and commit transcript + cancel Cancel active recording and discard transcript + status Print current state + devices List available input devices + doctor Run configuration and environment checks + version Print version information + help Show this help + +Flags: + --config PATH Config file path (default: $XDG_CONFIG_HOME/sotto/config.jsonc) + -h, --help Show help + --version Show version +`, binaryName) +} diff --git a/apps/sotto/internal/cli/cli_test.go b/apps/sotto/internal/cli/cli_test.go new file mode 100644 index 0000000..8a88400 --- /dev/null +++ b/apps/sotto/internal/cli/cli_test.go @@ -0,0 +1,115 @@ +package cli + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseDefaultsToHelp(t *testing.T) { + parsed, err := Parse(nil) + require.NoError(t, err) + require.True(t, parsed.ShowHelp) + require.Equal(t, CommandHelp, parsed.Command) +} + +func TestParseCommandWithConfig(t *testing.T) { + parsed, err := Parse([]string{"--config", "/tmp/sotto.conf", "doctor"}) + require.NoError(t, err) + require.Equal(t, CommandDoctor, parsed.Command) + require.Equal(t, "/tmp/sotto.conf", parsed.ConfigPath) + require.False(t, parsed.ShowHelp) +} + +func TestParseArgMatrix(t *testing.T) { + tests := []struct { + name string + args []string + wantErr string + wantCmd Command + wantHelp bool + wantPath string + }{ + { + name: "help short flag", + args: []string{"-h"}, + wantCmd: CommandHelp, + wantHelp: true, + }, + { + name: "help long flag", + args: []string{"--help"}, + wantCmd: CommandHelp, + wantHelp: true, + }, + { + name: "version flag", + args: []string{"--version"}, + wantCmd: CommandVersion, + wantHelp: false, + }, + { + name: "config after command", + args: []string{"status", "--config", "/tmp/cfg"}, + wantErr: "unexpected arguments after command", + }, + { + name: "missing config path", + args: []string{"--config"}, + wantErr: "requires a path", + }, + { + name: "unknown flag", + args: []string{"--bogus"}, + wantErr: "unknown flag", + }, + { + name: "unknown command", + args: []string{"bogus"}, + wantErr: "unknown command", + }, + { + name: "extra args after command", + args: []string{"doctor", "extra"}, + wantErr: "unexpected arguments", + }, + { + name: "valid cancel command", + args: []string{"cancel"}, + wantCmd: CommandCancel, + wantHelp: false, + }, + { + name: "valid stop with config", + args: []string{"--config", "/tmp/cfg", "stop"}, + wantCmd: CommandStop, + wantHelp: false, + wantPath: "/tmp/cfg", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + parsed, err := Parse(tc.args) + if tc.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + return + } + + require.NoError(t, err) + require.Equal(t, tc.wantCmd, parsed.Command) + require.Equal(t, tc.wantHelp, parsed.ShowHelp) + require.Equal(t, tc.wantPath, parsed.ConfigPath) + }) + } +} + +func TestHelpTextIncludesCoreCommands(t *testing.T) { + text := HelpText("sotto") + require.Contains(t, text, "toggle") + require.Contains(t, text, "stop") + require.Contains(t, text, "cancel") + require.Contains(t, text, "doctor") + require.Contains(t, text, "--config PATH") +} diff --git a/apps/sotto/internal/config/argv.go b/apps/sotto/internal/config/argv.go new file mode 100644 index 0000000..fcd9504 --- /dev/null +++ b/apps/sotto/internal/config/argv.go @@ -0,0 +1,74 @@ +package config + +import ( + "fmt" + "strings" + "unicode" +) + +// parseArgv tokenizes a command string into argv semantics without invoking a shell. +func parseArgv(input string) ([]string, error) { + input = strings.TrimSpace(input) + if input == "" { + return nil, nil + } + if strings.HasPrefix(input, "#") { + return nil, nil + } + + var ( + argv []string + current strings.Builder + quote rune + escape bool + ) + + flush := func() { + if current.Len() == 0 { + return + } + argv = append(argv, current.String()) + current.Reset() + } + + for _, r := range input { + switch { + case escape: + current.WriteRune(r) + escape = false + case r == '\\': + escape = true + case quote != 0: + if r == quote { + quote = 0 + continue + } + current.WriteRune(r) + case r == '\'' || r == '"': + quote = r + case unicode.IsSpace(r): + flush() + default: + current.WriteRune(r) + } + } + + if escape { + return nil, fmt.Errorf("unterminated escape sequence in command: %q", input) + } + if quote != 0 { + return nil, fmt.Errorf("unterminated quote in command: %q", input) + } + + flush() + return argv, nil +} + +// mustParseArgv is a startup helper used only for static default command values. +func mustParseArgv(input string) []string { + argv, err := parseArgv(input) + if err != nil { + panic(err) + } + return argv +} diff --git a/apps/sotto/internal/config/argv_test.go b/apps/sotto/internal/config/argv_test.go new file mode 100644 index 0000000..2ae0425 --- /dev/null +++ b/apps/sotto/internal/config/argv_test.go @@ -0,0 +1,44 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestParseArgv(t *testing.T) { + tests := []struct { + name string + input string + want []string + wantErr string + }{ + {name: "empty", input: "", want: nil}, + {name: "simple", input: "wl-copy --trim-newline", want: []string{"wl-copy", "--trim-newline"}}, + {name: "quoted spaces", input: `mycmd --name "hello world"`, want: []string{"mycmd", "--name", "hello world"}}, + {name: "single quote", input: `mycmd --name 'hello world'`, want: []string{"mycmd", "--name", "hello world"}}, + {name: "escaped space", input: `mycmd hello\ world`, want: []string{"mycmd", "hello world"}}, + {name: "leading comment", input: `# wl-copy --trim-newline`, want: nil}, + {name: "unterminated quote", input: `mycmd "oops`, wantErr: "unterminated quote"}, + {name: "unterminated escape", input: `mycmd hello\`, wantErr: "unterminated escape"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := parseArgv(tc.input) + if tc.wantErr != "" { + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + return + } + require.NoError(t, err) + require.Equal(t, tc.want, got) + }) + } +} + +func TestMustParseArgvPanicsOnInvalidInput(t *testing.T) { + require.Panics(t, func() { + _ = mustParseArgv(`mycmd "unterminated`) + }) +} diff --git a/apps/sotto/internal/config/defaults.go b/apps/sotto/internal/config/defaults.go new file mode 100644 index 0000000..ce87203 --- /dev/null +++ b/apps/sotto/internal/config/defaults.go @@ -0,0 +1,45 @@ +package config + +// Default returns the canonical runtime configuration used when no file is present. +func Default() Config { + clipboard := "wl-copy --trim-newline" + + return Config{ + RivaGRPC: "127.0.0.1:50051", + RivaHTTP: "127.0.0.1:9000", + RivaHealthPath: "/v1/health/ready", + Audio: AudioConfig{ + Input: "default", + Fallback: "default", + }, + Paste: PasteConfig{Enable: true, Shortcut: "CTRL,V"}, + ASR: ASRConfig{ + AutomaticPunctuation: true, + LanguageCode: "en-US", + Model: "", + }, + Transcript: TranscriptConfig{TrailingSpace: true}, + Indicator: IndicatorConfig{ + Enable: true, + Backend: "hypr", + DesktopAppName: "sotto-indicator", + SoundEnable: true, + SoundStartFile: "", + SoundStopFile: "", + SoundCompleteFile: "", + SoundCancelFile: "", + Height: 28, + TextRecording: "Recording…", + TextProcessing: "Transcribing…", + TextError: "Speech recognition error", + ErrorTimeoutMS: 1600, + }, + Clipboard: CommandConfig{Raw: clipboard, Argv: mustParseArgv(clipboard)}, + Vocab: VocabConfig{ + GlobalSets: nil, + Sets: map[string]VocabSet{}, + MaxPhrases: 1024, + }, + Debug: DebugConfig{}, + } +} diff --git a/apps/sotto/internal/config/load.go b/apps/sotto/internal/config/load.go new file mode 100644 index 0000000..92d3abe --- /dev/null +++ b/apps/sotto/internal/config/load.go @@ -0,0 +1,76 @@ +package config + +import ( + "errors" + "fmt" + "os" + "strings" +) + +// Loaded captures resolved config path, parsed values, and non-fatal warnings. +type Loaded struct { + Path string + Config Config + Warnings []Warning + Exists bool +} + +// Load resolves, reads, parses, and validates the runtime configuration. +func Load(explicitPath string) (Loaded, error) { + resolvedPath, err := ResolvePath(explicitPath) + if err != nil { + return Loaded{}, err + } + + base := Default() + loadedPath := resolvedPath + warnings := make([]Warning, 0) + + content, err := os.ReadFile(resolvedPath) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return Loaded{}, fmt.Errorf("read config %q: %w", resolvedPath, err) + } + + if strings.TrimSpace(explicitPath) == "" { + legacyPath := legacyPathFor(resolvedPath) + if legacyPath != "" { + legacyContent, legacyErr := os.ReadFile(legacyPath) + if legacyErr == nil { + content = legacyContent + loadedPath = legacyPath + warnings = append(warnings, Warning{ + Message: fmt.Sprintf("loaded legacy config path %q; migrate to %q (JSONC)", legacyPath, resolvedPath), + }) + } else if !errors.Is(legacyErr, os.ErrNotExist) { + return Loaded{}, fmt.Errorf("read config %q: %w", legacyPath, legacyErr) + } + } + } + + if content == nil { + warnings = append(warnings, Warning{ + Message: fmt.Sprintf("config file %q not found; using defaults", resolvedPath), + }) + return Loaded{ + Path: resolvedPath, + Config: base, + Warnings: warnings, + Exists: false, + }, nil + } + } + + cfg, parseWarnings, err := Parse(string(content), base) + if err != nil { + return Loaded{}, fmt.Errorf("parse config %q: %w", loadedPath, err) + } + warnings = append(warnings, parseWarnings...) + + return Loaded{ + Path: loadedPath, + Config: cfg, + Warnings: warnings, + Exists: true, + }, nil +} diff --git a/apps/sotto/internal/config/parser.go b/apps/sotto/internal/config/parser.go new file mode 100644 index 0000000..5be4909 --- /dev/null +++ b/apps/sotto/internal/config/parser.go @@ -0,0 +1,31 @@ +// Package config resolves, parses, validates, and defaults sotto configuration. +package config + +import "strings" + +const legacyFormatWarning = "legacy key=value config format is deprecated; migrate to JSONC" + +// Parse reads configuration content as JSONC (preferred) or legacy key/value format. +// +// JSONC is selected when the first non-whitespace character is `{`. +func Parse(content string, base Config) (Config, []Warning, error) { + trimmed := strings.TrimSpace(content) + if trimmed == "" { + validatedWarnings, err := Validate(base) + if err != nil { + return Config{}, nil, err + } + return base, validatedWarnings, nil + } + + if strings.HasPrefix(trimmed, "{") { + return parseJSONC(content, base) + } + + cfg, warnings, err := parseLegacy(content, base) + if err != nil { + return Config{}, nil, err + } + warnings = append([]Warning{{Message: legacyFormatWarning}}, warnings...) + return cfg, warnings, nil +} diff --git a/apps/sotto/internal/config/parser_jsonc.go b/apps/sotto/internal/config/parser_jsonc.go new file mode 100644 index 0000000..82f220c --- /dev/null +++ b/apps/sotto/internal/config/parser_jsonc.go @@ -0,0 +1,507 @@ +package config + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "strings" +) + +type jsoncConfig struct { + Riva *jsoncRiva `json:"riva"` + Audio *jsoncAudio `json:"audio"` + Paste *jsoncPaste `json:"paste"` + ASR *jsoncASR `json:"asr"` + Transcript *jsoncTranscript `json:"transcript"` + Indicator *jsoncIndicator `json:"indicator"` + + ClipboardCmd *string `json:"clipboard_cmd"` + PasteCmd *string `json:"paste_cmd"` + Vocab *jsoncVocab `json:"vocab"` + Debug *jsoncDebug `json:"debug"` +} + +type jsoncRiva struct { + GRPC *string `json:"grpc"` + HTTP *string `json:"http"` + HealthPath *string `json:"health_path"` +} + +type jsoncAudio struct { + Input *string `json:"input"` + Fallback *string `json:"fallback"` +} + +type jsoncPaste struct { + Enable *bool `json:"enable"` + Shortcut *string `json:"shortcut"` +} + +type jsoncASR struct { + AutomaticPunctuation *bool `json:"automatic_punctuation"` + LanguageCode *string `json:"language_code"` + Model *string `json:"model"` +} + +type jsoncTranscript struct { + TrailingSpace *bool `json:"trailing_space"` +} + +type jsoncIndicator struct { + Enable *bool `json:"enable"` + Backend *string `json:"backend"` + DesktopAppName *string `json:"desktop_app_name"` + SoundEnable *bool `json:"sound_enable"` + SoundStartFile *string `json:"sound_start_file"` + SoundStopFile *string `json:"sound_stop_file"` + SoundCompleteFile *string `json:"sound_complete_file"` + SoundCancelFile *string `json:"sound_cancel_file"` + Height *int `json:"height"` + TextRecording *string `json:"text_recording"` + TextProcessing *string `json:"text_processing"` + TextTranscribing *string `json:"text_transcribing"` + TextError *string `json:"text_error"` + ErrorTimeoutMS *int `json:"error_timeout_ms"` +} + +type jsoncVocab struct { + Global *jsoncStringList `json:"global"` + MaxPhrases *int `json:"max_phrases"` + Sets map[string]jsoncVocabSet `json:"sets"` +} + +type jsoncVocabSet struct { + Boost *float64 `json:"boost"` + Phrases []string `json:"phrases"` +} + +type jsoncDebug struct { + AudioDump *bool `json:"audio_dump"` + GRPCDump *bool `json:"grpc_dump"` +} + +type jsoncStringList []string + +func (l *jsoncStringList) UnmarshalJSON(data []byte) error { + var list []string + if err := json.Unmarshal(data, &list); err == nil { + *l = list + return nil + } + + var single string + if err := json.Unmarshal(data, &single); err == nil { + parts := strings.Split(single, ",") + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + out = append(out, part) + } + *l = out + return nil + } + + return fmt.Errorf("expected string array or comma-delimited string") +} + +func parseJSONC(content string, base Config) (Config, []Warning, error) { + normalized, err := normalizeJSONC(content) + if err != nil { + return Config{}, nil, err + } + + decoder := json.NewDecoder(strings.NewReader(normalized)) + decoder.DisallowUnknownFields() + + var payload jsoncConfig + if err := decoder.Decode(&payload); err != nil { + return Config{}, nil, wrapJSONDecodeError(normalized, err) + } + if err := ensureSingleJSONValue(decoder); err != nil { + return Config{}, nil, wrapJSONDecodeError(normalized, err) + } + + cfg := base + warnings, err := payload.applyTo(&cfg) + if err != nil { + return Config{}, nil, err + } + + validatedWarnings, err := Validate(cfg) + if err != nil { + return Config{}, nil, err + } + warnings = append(warnings, validatedWarnings...) + return cfg, warnings, nil +} + +func (payload jsoncConfig) applyTo(cfg *Config) ([]Warning, error) { + warnings := make([]Warning, 0) + + if payload.Riva != nil { + if payload.Riva.GRPC != nil { + cfg.RivaGRPC = *payload.Riva.GRPC + } + if payload.Riva.HTTP != nil { + cfg.RivaHTTP = *payload.Riva.HTTP + } + if payload.Riva.HealthPath != nil { + cfg.RivaHealthPath = *payload.Riva.HealthPath + } + } + + if payload.Audio != nil { + if payload.Audio.Input != nil { + cfg.Audio.Input = *payload.Audio.Input + } + if payload.Audio.Fallback != nil { + cfg.Audio.Fallback = *payload.Audio.Fallback + } + } + + if payload.Paste != nil { + if payload.Paste.Enable != nil { + cfg.Paste.Enable = *payload.Paste.Enable + } + if payload.Paste.Shortcut != nil { + cfg.Paste.Shortcut = strings.TrimSpace(*payload.Paste.Shortcut) + } + } + + if payload.ASR != nil { + if payload.ASR.AutomaticPunctuation != nil { + cfg.ASR.AutomaticPunctuation = *payload.ASR.AutomaticPunctuation + } + if payload.ASR.LanguageCode != nil { + cfg.ASR.LanguageCode = *payload.ASR.LanguageCode + } + if payload.ASR.Model != nil { + cfg.ASR.Model = *payload.ASR.Model + } + } + + if payload.Transcript != nil && payload.Transcript.TrailingSpace != nil { + cfg.Transcript.TrailingSpace = *payload.Transcript.TrailingSpace + } + + if payload.Indicator != nil { + if payload.Indicator.Enable != nil { + cfg.Indicator.Enable = *payload.Indicator.Enable + } + if payload.Indicator.Backend != nil { + cfg.Indicator.Backend = strings.TrimSpace(*payload.Indicator.Backend) + } + if payload.Indicator.DesktopAppName != nil { + cfg.Indicator.DesktopAppName = strings.TrimSpace(*payload.Indicator.DesktopAppName) + } + if payload.Indicator.SoundEnable != nil { + cfg.Indicator.SoundEnable = *payload.Indicator.SoundEnable + } + if payload.Indicator.SoundStartFile != nil { + cfg.Indicator.SoundStartFile = *payload.Indicator.SoundStartFile + } + if payload.Indicator.SoundStopFile != nil { + cfg.Indicator.SoundStopFile = *payload.Indicator.SoundStopFile + } + if payload.Indicator.SoundCompleteFile != nil { + cfg.Indicator.SoundCompleteFile = *payload.Indicator.SoundCompleteFile + } + if payload.Indicator.SoundCancelFile != nil { + cfg.Indicator.SoundCancelFile = *payload.Indicator.SoundCancelFile + } + if payload.Indicator.Height != nil { + cfg.Indicator.Height = *payload.Indicator.Height + } + if payload.Indicator.TextRecording != nil { + cfg.Indicator.TextRecording = *payload.Indicator.TextRecording + } + if payload.Indicator.TextTranscribing != nil { + cfg.Indicator.TextProcessing = *payload.Indicator.TextTranscribing + warnings = append(warnings, Warning{Message: "indicator.text_transcribing is deprecated; use indicator.text_processing"}) + } + if payload.Indicator.TextProcessing != nil { + cfg.Indicator.TextProcessing = *payload.Indicator.TextProcessing + } + if payload.Indicator.TextError != nil { + cfg.Indicator.TextError = *payload.Indicator.TextError + } + if payload.Indicator.ErrorTimeoutMS != nil { + cfg.Indicator.ErrorTimeoutMS = *payload.Indicator.ErrorTimeoutMS + } + } + + if payload.ClipboardCmd != nil { + raw := *payload.ClipboardCmd + argv, err := parseArgv(raw) + if err != nil { + return nil, fmt.Errorf("invalid clipboard_cmd: %w", err) + } + cfg.Clipboard = CommandConfig{Raw: raw, Argv: argv} + } + + if payload.PasteCmd != nil { + raw := *payload.PasteCmd + argv, err := parseArgv(raw) + if err != nil { + return nil, fmt.Errorf("invalid paste_cmd: %w", err) + } + cfg.PasteCmd = CommandConfig{Raw: raw, Argv: argv} + } + + if payload.Vocab != nil { + if payload.Vocab.Global != nil { + cfg.Vocab.GlobalSets = cfg.Vocab.GlobalSets[:0] + for _, name := range *payload.Vocab.Global { + name = strings.TrimSpace(name) + if name == "" { + continue + } + cfg.Vocab.GlobalSets = append(cfg.Vocab.GlobalSets, name) + } + } + if payload.Vocab.MaxPhrases != nil { + cfg.Vocab.MaxPhrases = *payload.Vocab.MaxPhrases + } + if payload.Vocab.Sets != nil { + if cfg.Vocab.Sets == nil { + cfg.Vocab.Sets = make(map[string]VocabSet) + } + for name, set := range payload.Vocab.Sets { + trimmedName := strings.TrimSpace(name) + if trimmedName == "" { + return nil, fmt.Errorf("vocab.sets contains an empty set name") + } + + phrases := make([]string, 0, len(set.Phrases)) + phrases = append(phrases, set.Phrases...) + + entry := VocabSet{Name: trimmedName, Phrases: phrases} + if set.Boost != nil { + entry.Boost = *set.Boost + } + cfg.Vocab.Sets[trimmedName] = entry + } + } + } + + if payload.Debug != nil { + if payload.Debug.AudioDump != nil { + cfg.Debug.EnableAudioDump = *payload.Debug.AudioDump + } + if payload.Debug.GRPCDump != nil { + cfg.Debug.EnableGRPCDump = *payload.Debug.GRPCDump + } + } + + return warnings, nil +} + +func normalizeJSONC(content string) (string, error) { + withoutComments, err := stripJSONCComments(content) + if err != nil { + return "", err + } + return stripJSONCTrailingCommas(withoutComments), nil +} + +func stripJSONCComments(content string) (string, error) { + var out strings.Builder + out.Grow(len(content)) + + inString := false + escape := false + lineComment := false + blockComment := false + + for i := 0; i < len(content); i++ { + ch := content[i] + + if lineComment { + if ch == '\n' { + lineComment = false + out.WriteByte(ch) + continue + } + if ch == '\r' { + lineComment = false + out.WriteByte(ch) + continue + } + out.WriteByte(' ') + continue + } + + if blockComment { + if ch == '*' && i+1 < len(content) && content[i+1] == '/' { + blockComment = false + out.WriteString(" ") + i++ + continue + } + if ch == '\n' || ch == '\r' || ch == '\t' { + out.WriteByte(ch) + } else { + out.WriteByte(' ') + } + continue + } + + if inString { + out.WriteByte(ch) + if escape { + escape = false + continue + } + if ch == '\\' { + escape = true + continue + } + if ch == '"' { + inString = false + } + continue + } + + if ch == '"' { + inString = true + out.WriteByte(ch) + continue + } + + if ch == '/' && i+1 < len(content) { + next := content[i+1] + if next == '/' { + lineComment = true + out.WriteString(" ") + i++ + continue + } + if next == '*' { + blockComment = true + out.WriteString(" ") + i++ + continue + } + } + + out.WriteByte(ch) + } + + if blockComment { + return "", fmt.Errorf("unterminated block comment in JSONC") + } + + return out.String(), nil +} + +func stripJSONCTrailingCommas(content string) string { + var out strings.Builder + out.Grow(len(content)) + + inString := false + escape := false + + for i := 0; i < len(content); i++ { + ch := content[i] + + if inString { + out.WriteByte(ch) + if escape { + escape = false + continue + } + if ch == '\\' { + escape = true + continue + } + if ch == '"' { + inString = false + } + continue + } + + if ch == '"' { + inString = true + out.WriteByte(ch) + continue + } + + if ch == ',' { + j := i + 1 + for j < len(content) && isJSONWhitespace(content[j]) { + j++ + } + if j < len(content) && (content[j] == '}' || content[j] == ']') { + continue + } + } + + out.WriteByte(ch) + } + + return out.String() +} + +func isJSONWhitespace(ch byte) bool { + switch ch { + case ' ', '\n', '\r', '\t': + return true + default: + return false + } +} + +func ensureSingleJSONValue(decoder *json.Decoder) error { + var extra struct{} + err := decoder.Decode(&extra) + if errors.Is(err, io.EOF) { + return nil + } + if err == nil { + return fmt.Errorf("multiple JSON values are not allowed") + } + return err +} + +func wrapJSONDecodeError(content string, err error) error { + var syntaxErr *json.SyntaxError + if errors.As(err, &syntaxErr) { + line, col := offsetToLineCol(content, syntaxErr.Offset) + return fmt.Errorf("line %d column %d: %w", line, col, err) + } + + var typeErr *json.UnmarshalTypeError + if errors.As(err, &typeErr) { + line, col := offsetToLineCol(content, typeErr.Offset) + return fmt.Errorf("line %d column %d: %w", line, col, err) + } + + return err +} + +func offsetToLineCol(content string, offset int64) (int, int) { + if offset <= 0 { + return 1, 1 + } + + limit := int(offset) + if limit > len(content) { + limit = len(content) + } + + line := 1 + col := 1 + for i := 0; i < limit-1; i++ { + if content[i] == '\n' { + line++ + col = 1 + continue + } + col++ + } + return line, col +} diff --git a/apps/sotto/internal/config/parser_legacy.go b/apps/sotto/internal/config/parser_legacy.go new file mode 100644 index 0000000..1e29543 --- /dev/null +++ b/apps/sotto/internal/config/parser_legacy.go @@ -0,0 +1,498 @@ +package config + +import ( + "bufio" + "fmt" + "strconv" + "strings" +) + +// parseState tracks block-level parser context while scanning lines. +type parseState struct { + inVocabSet *VocabSet + vocabSetStartLine int +} + +// parseLegacy applies the legacy line-oriented key/value config grammar. +func parseLegacy(content string, base Config) (Config, []Warning, error) { + cfg := base + warnings := make([]Warning, 0) + state := &parseState{} + + scanner := bufio.NewScanner(strings.NewReader(content)) + for line := 1; scanner.Scan(); line++ { + raw := scanner.Text() + trimmed := strings.TrimSpace(stripComments(raw)) + if trimmed == "" { + continue + } + + if state.inVocabSet != nil { + if trimmed == "}" { + if cfg.Vocab.Sets == nil { + cfg.Vocab.Sets = make(map[string]VocabSet) + } + cfg.Vocab.Sets[state.inVocabSet.Name] = *state.inVocabSet + state.inVocabSet = nil + state.vocabSetStartLine = 0 + continue + } + + key, value, err := parseAssignment(trimmed) + if err != nil { + return Config{}, nil, lineError(line, err) + } + if err := applyVocabSetKey(state.inVocabSet, key, value); err != nil { + return Config{}, nil, lineError(line, err) + } + continue + } + + if strings.HasPrefix(trimmed, "vocabset ") { + set, err := parseVocabSetHeader(trimmed) + if err != nil { + return Config{}, nil, lineError(line, err) + } + if _, exists := cfg.Vocab.Sets[set.Name]; exists { + warnings = append(warnings, Warning{ + Line: line, + Message: fmt.Sprintf("vocabset %q redefined; last definition wins", set.Name), + }) + } + state.inVocabSet = &set + state.vocabSetStartLine = line + continue + } + + key, value, err := parseAssignment(trimmed) + if err != nil { + return Config{}, nil, lineError(line, err) + } + if err := applyRootKey(&cfg, key, value); err != nil { + return Config{}, nil, lineError(line, err) + } + } + + if err := scanner.Err(); err != nil { + return Config{}, nil, err + } + + if state.inVocabSet != nil { + line := state.vocabSetStartLine + if line <= 0 { + line = scannerPosition(content) + } + return Config{}, nil, fmt.Errorf("line %d: unterminated vocabset %q block", line, state.inVocabSet.Name) + } + + validatedWarnings, err := Validate(cfg) + if err != nil { + return Config{}, nil, err + } + warnings = append(warnings, validatedWarnings...) + + return cfg, warnings, nil +} + +// parseAssignment parses a single `key = value` expression. +func parseAssignment(line string) (string, string, error) { + idx := strings.Index(line, "=") + if idx < 0 { + return "", "", fmt.Errorf("expected key = value") + } + key := strings.TrimSpace(line[:idx]) + value := strings.TrimSpace(line[idx+1:]) + if key == "" { + return "", "", fmt.Errorf("empty key") + } + if value == "" { + return "", "", fmt.Errorf("missing value for key %q", key) + } + return key, value, nil +} + +// parseVocabSetHeader parses `vocabset {` headers. +func parseVocabSetHeader(line string) (VocabSet, error) { + if !strings.HasSuffix(line, "{") { + return VocabSet{}, fmt.Errorf("vocabset declaration must end with '{'") + } + line = strings.TrimSpace(strings.TrimSuffix(line, "{")) + parts := strings.Fields(line) + if len(parts) != 2 { + return VocabSet{}, fmt.Errorf("invalid vocabset declaration; expected: vocabset {") + } + if parts[0] != "vocabset" { + return VocabSet{}, fmt.Errorf("invalid block type %q", parts[0]) + } + + return VocabSet{Name: parts[1]}, nil +} + +// applyVocabSetKey applies an assignment within an active vocabset block. +func applyVocabSetKey(set *VocabSet, key, value string) error { + switch key { + case "boost": + f, err := strconv.ParseFloat(value, 64) + if err != nil { + return fmt.Errorf("invalid boost value: %w", err) + } + set.Boost = f + case "phrases": + phrases, err := parseStringList(value) + if err != nil { + return err + } + set.Phrases = phrases + default: + return fmt.Errorf("unknown vocabset key %q", key) + } + return nil +} + +// applyRootKey applies one top-level key/value assignment into cfg. +func applyRootKey(cfg *Config, key, value string) error { + switch key { + case "riva_grpc": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.RivaGRPC = v + case "riva_http": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.RivaHTTP = v + case "riva_health_path": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.RivaHealthPath = v + case "audio.input": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Audio.Input = v + case "audio.fallback": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Audio.Fallback = v + case "paste.enable": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid bool for paste.enable: %w", err) + } + cfg.Paste.Enable = b + case "paste.shortcut": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Paste.Shortcut = strings.TrimSpace(v) + case "asr.automatic_punctuation": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid bool for asr.automatic_punctuation: %w", err) + } + cfg.ASR.AutomaticPunctuation = b + case "asr.language_code": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.ASR.LanguageCode = v + case "asr.model": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.ASR.Model = v + case "transcript.trailing_space": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid bool for transcript.trailing_space: %w", err) + } + cfg.Transcript.TrailingSpace = b + case "indicator.enable": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid bool for indicator.enable: %w", err) + } + cfg.Indicator.Enable = b + case "indicator.backend": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.Backend = strings.TrimSpace(v) + case "indicator.desktop_app_name": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.DesktopAppName = strings.TrimSpace(v) + case "indicator.sound_enable": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid bool for indicator.sound_enable: %w", err) + } + cfg.Indicator.SoundEnable = b + case "indicator.sound_start_file": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.SoundStartFile = v + case "indicator.sound_stop_file": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.SoundStopFile = v + case "indicator.sound_complete_file": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.SoundCompleteFile = v + case "indicator.sound_cancel_file": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.SoundCancelFile = v + case "indicator.height": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid int for indicator.height: %w", err) + } + cfg.Indicator.Height = n + case "indicator.text_recording": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.TextRecording = v + case "indicator.text_processing", "indicator.text_transcribing": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.TextProcessing = v + case "indicator.text_error": + v, err := parseStringValue(value) + if err != nil { + return err + } + cfg.Indicator.TextError = v + case "indicator.error_timeout_ms": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid int for indicator.error_timeout_ms: %w", err) + } + cfg.Indicator.ErrorTimeoutMS = n + case "clipboard_cmd": + v, err := parseStringValue(value) + if err != nil { + return err + } + argv, err := parseArgv(v) + if err != nil { + return fmt.Errorf("invalid clipboard_cmd: %w", err) + } + cfg.Clipboard = CommandConfig{Raw: v, Argv: argv} + case "paste_cmd": + v, err := parseStringValue(value) + if err != nil { + return err + } + argv, err := parseArgv(v) + if err != nil { + return fmt.Errorf("invalid paste_cmd: %w", err) + } + cfg.PasteCmd = CommandConfig{Raw: v, Argv: argv} + case "vocab.global": + sets := strings.Split(value, ",") + cfg.Vocab.GlobalSets = cfg.Vocab.GlobalSets[:0] + for _, s := range sets { + s = strings.TrimSpace(s) + if s == "" { + continue + } + cfg.Vocab.GlobalSets = append(cfg.Vocab.GlobalSets, s) + } + case "vocab.max_phrases": + n, err := strconv.Atoi(value) + if err != nil { + return fmt.Errorf("invalid int for vocab.max_phrases: %w", err) + } + cfg.Vocab.MaxPhrases = n + case "debug.audio_dump": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid bool for debug.audio_dump: %w", err) + } + cfg.Debug.EnableAudioDump = b + case "debug.grpc_dump": + b, err := strconv.ParseBool(value) + if err != nil { + return fmt.Errorf("invalid bool for debug.grpc_dump: %w", err) + } + cfg.Debug.EnableGRPCDump = b + default: + return fmt.Errorf("unknown key %q", key) + } + + return nil +} + +// parseStringValue parses quoted or unquoted scalar strings. +func parseStringValue(raw string) (string, error) { + raw = strings.TrimSpace(raw) + if raw == "" { + return "", fmt.Errorf("value cannot be empty") + } + + if strings.HasPrefix(raw, "\"") { + quoted, err := strconv.Unquote(raw) + if err != nil { + return "", fmt.Errorf("invalid quoted string %q: %w", raw, err) + } + return quoted, nil + } + if strings.HasPrefix(raw, "'") { + return parseSingleQuotedString(raw) + } + + return raw, nil +} + +// parseSingleQuotedString parses shell-style single-quoted values. +func parseSingleQuotedString(raw string) (string, error) { + if len(raw) < 2 || !strings.HasSuffix(raw, "'") { + return "", fmt.Errorf("invalid quoted string %q: missing closing single quote", raw) + } + + inner := raw[1 : len(raw)-1] + var ( + out strings.Builder + escape bool + ) + for _, r := range inner { + switch { + case escape: + out.WriteRune(r) + escape = false + case r == '\\': + escape = true + default: + out.WriteRune(r) + } + } + if escape { + return "", fmt.Errorf("invalid quoted string %q: unterminated escape", raw) + } + return out.String(), nil +} + +// parseStringList parses `[ ... ]` phrase arrays. +func parseStringList(raw string) ([]string, error) { + raw = strings.TrimSpace(raw) + if !strings.HasPrefix(raw, "[") || !strings.HasSuffix(raw, "]") { + return nil, fmt.Errorf("phrases must be in [ ... ]") + } + + raw = strings.TrimSpace(strings.TrimSuffix(strings.TrimPrefix(raw, "["), "]")) + if raw == "" { + return nil, nil + } + + parts := splitCommaAware(raw) + out := make([]string, 0, len(parts)) + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + val, err := parseStringValue(part) + if err != nil { + return nil, fmt.Errorf("invalid phrase %q: %w", part, err) + } + out = append(out, val) + } + return out, nil +} + +// splitCommaAware splits by commas while preserving quoted commas. +func splitCommaAware(input string) []string { + var ( + parts []string + start int + quote rune + escape bool + ) + + for i, r := range input { + switch { + case escape: + escape = false + case r == '\\': + escape = true + case quote != 0: + if r == quote { + quote = 0 + } + case r == '\'' || r == '"': + quote = r + case r == ',': + parts = append(parts, input[start:i]) + start = i + 1 + } + } + + parts = append(parts, input[start:]) + return parts +} + +// stripComments removes # comments unless they appear inside quotes. +func stripComments(line string) string { + var ( + quote rune + escape bool + ) + for i, r := range line { + switch { + case escape: + escape = false + case r == '\\': + escape = true + case quote != 0: + if r == quote { + quote = 0 + } + case r == '\'' || r == '"': + quote = r + case r == '#': + return line[:i] + } + } + return line +} + +// lineError wraps parser errors with 1-indexed source line context. +func lineError(line int, err error) error { + return fmt.Errorf("line %d: %w", line, err) +} + +// scannerPosition returns the final logical line number for EOF errors. +func scannerPosition(content string) int { + if content == "" { + return 1 + } + return strings.Count(content, "\n") + 1 +} diff --git a/apps/sotto/internal/config/parser_test.go b/apps/sotto/internal/config/parser_test.go new file mode 100644 index 0000000..6b071cd --- /dev/null +++ b/apps/sotto/internal/config/parser_test.go @@ -0,0 +1,281 @@ +package config + +import ( + "strings" + "testing" +) + +func TestParseValidJSONCConfig(t *testing.T) { + input := ` +{ + // local endpoints + "riva": { + "grpc": "127.0.0.1:50051", + "http": "127.0.0.1:9000" + }, + "audio": { + "input": "Elgato" + }, + "paste": { + "enable": true, + "shortcut": "SUPER,V" + }, + "vocab": { + "global": ["core", "team"], + "sets": { + "core": { + "boost": 14, + "phrases": ["Sotto", "Hyprland"] + }, + "team": { + "boost": 18, + "phrases": ["Sotto", "Riva"] + } + } + }, +} +` + + cfg, warnings, err := Parse(input, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cfg.RivaGRPC != "127.0.0.1:50051" { + t.Fatalf("unexpected riva.grpc: %s", cfg.RivaGRPC) + } + if cfg.Audio.Input != "Elgato" { + t.Fatalf("unexpected audio.input: %s", cfg.Audio.Input) + } + if cfg.Paste.Shortcut != "SUPER,V" { + t.Fatalf("unexpected paste.shortcut: %s", cfg.Paste.Shortcut) + } + if len(warnings) == 0 { + t.Fatalf("expected dedupe warning for repeated phrase") + } + + phrases, _, err := BuildSpeechPhrases(cfg) + if err != nil { + t.Fatalf("BuildSpeechPhrases() error = %v", err) + } + if len(phrases) != 3 { + t.Fatalf("expected 3 unique phrases, got %d", len(phrases)) + } + + for _, p := range phrases { + if p.Phrase == "Sotto" && p.Boost != 18 { + t.Fatalf("expected highest boost retained for Sotto; got %v", p.Boost) + } + } +} + +func TestParseLegacyFormatStillSupportedWithWarning(t *testing.T) { + cfg, warnings, err := Parse(` +riva_grpc = 127.0.0.1:50051 +paste.enable = false +`, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cfg.RivaGRPC != "127.0.0.1:50051" { + t.Fatalf("unexpected riva_grpc: %s", cfg.RivaGRPC) + } + if cfg.Paste.Enable { + t.Fatalf("expected paste.enable=false") + } + + found := false + for _, w := range warnings { + if strings.Contains(w.Message, "legacy") { + found = true + break + } + } + if !found { + t.Fatalf("expected legacy format warning, warnings=%+v", warnings) + } +} + +func TestParseJSONCUnknownKeyFails(t *testing.T) { + _, _, err := Parse(`{"foo": {"bar": 1}}`, Default()) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "unknown field") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParseJSONCLineNumberOnError(t *testing.T) { + _, _, err := Parse(` +{ + "riva": { + "grpc": "127.0.0.1:50051" + "http": "127.0.0.1:9000" + } +} +`, Default()) + if err == nil { + t.Fatal("expected error") + } + if !strings.Contains(err.Error(), "line") { + t.Fatalf("expected line number in error, got %v", err) + } +} + +func TestValidateMissingVocabSetReference(t *testing.T) { + cfg := Default() + cfg.Vocab.GlobalSets = []string{"missing"} + + if _, err := Validate(cfg); err == nil { + t.Fatal("expected error for missing vocab set") + } +} + +func TestValidateMaxPhraseLimit(t *testing.T) { + cfg := Default() + cfg.Vocab.MaxPhrases = 1 + cfg.Vocab.GlobalSets = []string{"team"} + cfg.Vocab.Sets["team"] = VocabSet{ + Name: "team", + Boost: 10, + Phrases: []string{"one", "two"}, + } + + _, err := Validate(cfg) + if err == nil { + t.Fatal("expected max phrase limit error") + } + if !strings.Contains(err.Error(), "exceeds") { + t.Fatalf("unexpected error: %v", err) + } +} + +func TestParseCommandArgvQuoted(t *testing.T) { + cfg, _, err := Parse(`{"paste_cmd":"mycmd --name 'hello world'"}`, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + got := strings.Join(cfg.PasteCmd.Argv, "|") + want := "mycmd|--name|hello world" + if got != want { + t.Fatalf("unexpected argv parse: got %q want %q", got, want) + } +} + +func TestParsePasteShortcut(t *testing.T) { + cfg, _, err := Parse(`{"paste":{"shortcut":"SUPER,V"}}`, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cfg.Paste.Shortcut != "SUPER,V" { + t.Fatalf("unexpected paste.shortcut: %q", cfg.Paste.Shortcut) + } +} + +func TestParseIndicatorBackend(t *testing.T) { + cfg, _, err := Parse(` +{ + "indicator": { + "backend": "desktop", + "desktop_app_name": "sotto-indicator" + } +} +`, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cfg.Indicator.Backend != "desktop" { + t.Fatalf("expected indicator.backend=desktop, got %q", cfg.Indicator.Backend) + } + if cfg.Indicator.DesktopAppName != "sotto-indicator" { + t.Fatalf("unexpected indicator.desktop_app_name: %q", cfg.Indicator.DesktopAppName) + } +} + +func TestParseIndicatorTextTranscribingAliasWarning(t *testing.T) { + cfg, warnings, err := Parse(`{"indicator":{"text_transcribing":"Working..."}}`, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cfg.Indicator.TextProcessing != "Working..." { + t.Fatalf("unexpected text processing value: %q", cfg.Indicator.TextProcessing) + } + + found := false + for _, w := range warnings { + if strings.Contains(w.Message, "text_transcribing") { + found = true + break + } + } + if !found { + t.Fatalf("expected alias warning, warnings=%+v", warnings) + } +} + +func TestParseIndicatorSoundEnable(t *testing.T) { + cfg, _, err := Parse(`{"indicator":{"sound_enable":false}}`, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cfg.Indicator.SoundEnable { + t.Fatalf("expected indicator.sound_enable=false") + } +} + +func TestParseIndicatorSoundFiles(t *testing.T) { + cfg, _, err := Parse(` +{ + "indicator": { + "sound_start_file": "/tmp/start.wav", + "sound_stop_file": "/tmp/stop.wav", + "sound_complete_file": "/tmp/complete.wav", + "sound_cancel_file": "/tmp/cancel.wav" + } +} +`, Default()) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + if cfg.Indicator.SoundStartFile != "/tmp/start.wav" { + t.Fatalf("unexpected start file: %q", cfg.Indicator.SoundStartFile) + } + if cfg.Indicator.SoundStopFile != "/tmp/stop.wav" { + t.Fatalf("unexpected stop file: %q", cfg.Indicator.SoundStopFile) + } + if cfg.Indicator.SoundCompleteFile != "/tmp/complete.wav" { + t.Fatalf("unexpected complete file: %q", cfg.Indicator.SoundCompleteFile) + } + if cfg.Indicator.SoundCancelFile != "/tmp/cancel.wav" { + t.Fatalf("unexpected cancel file: %q", cfg.Indicator.SoundCancelFile) + } +} + +func TestParseInitializesNilVocabMap(t *testing.T) { + base := Default() + base.Vocab.Sets = nil + + cfg, _, err := Parse(` +{ + "vocab": { + "sets": { + "team": { + "boost": 10, + "phrases": ["sotto"] + } + } + } +} +`, base) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + if cfg.Vocab.Sets == nil { + t.Fatal("expected vocab map to be initialized") + } + if _, ok := cfg.Vocab.Sets["team"]; !ok { + t.Fatalf("expected parsed vocab set to be present") + } +} diff --git a/apps/sotto/internal/config/path.go b/apps/sotto/internal/config/path.go new file mode 100644 index 0000000..7aba784 --- /dev/null +++ b/apps/sotto/internal/config/path.go @@ -0,0 +1,33 @@ +package config + +import ( + "errors" + "os" + "path/filepath" + "strings" +) + +// ResolvePath applies CLI/XDG/home fallback rules for config.jsonc location. +func ResolvePath(explicit string) (string, error) { + if strings.TrimSpace(explicit) != "" { + return explicit, nil + } + + if xdg := strings.TrimSpace(os.Getenv("XDG_CONFIG_HOME")); xdg != "" { + return filepath.Join(xdg, "sotto", "config.jsonc"), nil + } + + home, err := os.UserHomeDir() + if err != nil { + return "", errors.New("unable to resolve user home for config fallback") + } + + return filepath.Join(home, ".config", "sotto", "config.jsonc"), nil +} + +func legacyPathFor(path string) string { + if strings.HasSuffix(path, "config.jsonc") { + return strings.TrimSuffix(path, "config.jsonc") + "config.conf" + } + return "" +} diff --git a/apps/sotto/internal/config/path_load_test.go b/apps/sotto/internal/config/path_load_test.go new file mode 100644 index 0000000..99f9ab9 --- /dev/null +++ b/apps/sotto/internal/config/path_load_test.go @@ -0,0 +1,96 @@ +package config + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResolvePathPrecedence(t *testing.T) { + explicit := "/tmp/custom.jsonc" + resolved, err := ResolvePath(explicit) + require.NoError(t, err) + require.Equal(t, explicit, resolved) + + xdg := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", xdg) + resolved, err = ResolvePath("") + require.NoError(t, err) + require.Equal(t, filepath.Join(xdg, "sotto", "config.jsonc"), resolved) + + t.Setenv("XDG_CONFIG_HOME", "") + home := t.TempDir() + t.Setenv("HOME", home) + resolved, err = ResolvePath("") + require.NoError(t, err) + require.Equal(t, filepath.Join(home, ".config", "sotto", "config.jsonc"), resolved) +} + +func TestLoadMissingConfigUsesDefaultsWithWarning(t *testing.T) { + path := filepath.Join(t.TempDir(), "missing.jsonc") + + loaded, err := Load(path) + require.NoError(t, err) + require.Equal(t, path, loaded.Path) + require.False(t, loaded.Exists) + require.Equal(t, Default(), loaded.Config) + require.NotEmpty(t, loaded.Warnings) + require.Contains(t, loaded.Warnings[0].Message, "not found") +} + +func TestLoadExistingJSONCParsesAndValidates(t *testing.T) { + path := filepath.Join(t.TempDir(), "config.jsonc") + contents := ` +{ + "riva": { + "grpc": "127.0.0.1:50051", + "http": "127.0.0.1:9000" + }, + "audio": { + "input": "default", + "fallback": "default" + }, + "paste": { + "enable": false + } +} +` + require.NoError(t, os.WriteFile(path, []byte(contents), 0o600)) + + loaded, err := Load(path) + require.NoError(t, err) + require.True(t, loaded.Exists) + require.Equal(t, path, loaded.Path) + require.Equal(t, "127.0.0.1:50051", loaded.Config.RivaGRPC) + require.Equal(t, "127.0.0.1:9000", loaded.Config.RivaHTTP) + require.False(t, loaded.Config.Paste.Enable) +} + +func TestLoadImplicitPathFallsBackToLegacyConfigConf(t *testing.T) { + xdg := t.TempDir() + t.Setenv("XDG_CONFIG_HOME", xdg) + + legacyPath := filepath.Join(xdg, "sotto", "config.conf") + require.NoError(t, os.MkdirAll(filepath.Dir(legacyPath), 0o700)) + require.NoError(t, os.WriteFile(legacyPath, []byte("paste.enable = false\n"), 0o600)) + + loaded, err := Load("") + require.NoError(t, err) + require.True(t, loaded.Exists) + require.Equal(t, legacyPath, loaded.Path) + require.False(t, loaded.Config.Paste.Enable) + require.NotEmpty(t, loaded.Warnings) + require.Contains(t, loaded.Warnings[0].Message, "legacy config path") +} + +func TestLoadParseErrorIncludesPath(t *testing.T) { + path := filepath.Join(t.TempDir(), "broken.jsonc") + require.NoError(t, os.WriteFile(path, []byte("{ not-json }"), 0o600)) + + _, err := Load(path) + require.Error(t, err) + require.Contains(t, err.Error(), "parse config") + require.Contains(t, err.Error(), path) +} diff --git a/apps/sotto/internal/config/types.go b/apps/sotto/internal/config/types.go new file mode 100644 index 0000000..917a92a --- /dev/null +++ b/apps/sotto/internal/config/types.go @@ -0,0 +1,97 @@ +// Package config resolves, parses, validates, and defaults sotto configuration. +package config + +// Config is the fully materialized runtime configuration used by sotto. +type Config struct { + RivaGRPC string + RivaHTTP string + RivaHealthPath string + Audio AudioConfig + Paste PasteConfig + ASR ASRConfig + Transcript TranscriptConfig + Indicator IndicatorConfig + Clipboard CommandConfig + PasteCmd CommandConfig + Vocab VocabConfig + Debug DebugConfig +} + +// AudioConfig controls preferred and fallback input-source selection. +type AudioConfig struct { + Input string + Fallback string +} + +// PasteConfig controls post-commit paste behavior. +type PasteConfig struct { + Enable bool + Shortcut string +} + +// ASRConfig controls request-level hints passed to Riva. +type ASRConfig struct { + AutomaticPunctuation bool + LanguageCode string + Model string +} + +// TranscriptConfig controls transcript assembly formatting. +type TranscriptConfig struct { + TrailingSpace bool +} + +// IndicatorConfig controls visual indicator and audio cue behavior. +type IndicatorConfig struct { + Enable bool + Backend string + DesktopAppName string + SoundEnable bool + SoundStartFile string + SoundStopFile string + SoundCompleteFile string + SoundCancelFile string + Height int + TextRecording string + TextProcessing string + TextError string + ErrorTimeoutMS int +} + +// CommandConfig stores a raw command string and its parsed argv form. +type CommandConfig struct { + Raw string + Argv []string +} + +// VocabConfig controls enabled speech phrase sets and dedupe limits. +type VocabConfig struct { + GlobalSets []string + Sets map[string]VocabSet + MaxPhrases int +} + +// VocabSet is one named phrase group with a shared boost value. +type VocabSet struct { + Name string + Boost float64 + Phrases []string +} + +// DebugConfig controls optional debug artifact output. +type DebugConfig struct { + EnableAudioDump bool + EnableGRPCDump bool +} + +// Warning is a non-fatal parse/validation message. +type Warning struct { + Line int + Message string +} + +// SpeechPhrase is the normalized phrase payload sent to ASR adapters. +type SpeechPhrase struct { + Phrase string + Boost float32 +} diff --git a/apps/sotto/internal/config/validate.go b/apps/sotto/internal/config/validate.go new file mode 100644 index 0000000..75b46af --- /dev/null +++ b/apps/sotto/internal/config/validate.go @@ -0,0 +1,120 @@ +package config + +import ( + "fmt" + "sort" + "strings" +) + +// Validate enforces config invariants and returns non-fatal warnings. +func Validate(cfg Config) ([]Warning, error) { + warnings := make([]Warning, 0) + + if strings.TrimSpace(cfg.RivaGRPC) == "" { + return nil, fmt.Errorf("riva_grpc must not be empty") + } + if strings.TrimSpace(cfg.RivaHTTP) == "" { + return nil, fmt.Errorf("riva_http must not be empty") + } + if strings.TrimSpace(cfg.RivaHealthPath) == "" { + return nil, fmt.Errorf("riva_health_path must not be empty") + } + if !strings.HasPrefix(strings.TrimSpace(cfg.RivaHealthPath), "/") { + return nil, fmt.Errorf("riva_health_path must start with '/'") + } + if strings.TrimSpace(cfg.ASR.LanguageCode) == "" { + return nil, fmt.Errorf("asr.language_code must not be empty") + } + backend := strings.ToLower(strings.TrimSpace(cfg.Indicator.Backend)) + if backend == "" { + return nil, fmt.Errorf("indicator.backend must not be empty") + } + if backend != "hypr" && backend != "desktop" { + return nil, fmt.Errorf("indicator.backend must be one of: hypr, desktop") + } + if backend == "desktop" && strings.TrimSpace(cfg.Indicator.DesktopAppName) == "" { + return nil, fmt.Errorf("indicator.desktop_app_name must not be empty when indicator.backend=desktop") + } + if cfg.Indicator.Height <= 0 { + return nil, fmt.Errorf("indicator.height must be > 0") + } + if cfg.Indicator.ErrorTimeoutMS < 0 { + return nil, fmt.Errorf("indicator.error_timeout_ms must be >= 0") + } + if cfg.Vocab.MaxPhrases <= 0 { + return nil, fmt.Errorf("vocab.max_phrases must be > 0") + } + if len(cfg.Clipboard.Argv) == 0 { + return nil, fmt.Errorf("clipboard_cmd must not be empty") + } + + if cfg.Paste.Enable && cfg.PasteCmd.Raw != "" && len(cfg.PasteCmd.Argv) == 0 { + return nil, fmt.Errorf("paste_cmd is configured but empty") + } + if cfg.Paste.Enable && len(cfg.PasteCmd.Argv) == 0 && strings.TrimSpace(cfg.Paste.Shortcut) == "" { + return nil, fmt.Errorf("paste.shortcut must not be empty when paste.enable=true and paste_cmd is unset") + } + + _, vocabWarnings, err := BuildSpeechPhrases(cfg) + if err != nil { + return nil, err + } + warnings = append(warnings, vocabWarnings...) + + return warnings, nil +} + +// BuildSpeechPhrases merges enabled vocab sets into deterministic ASR phrase payloads. +func BuildSpeechPhrases(cfg Config) ([]SpeechPhrase, []Warning, error) { + enabledSets := cfg.Vocab.GlobalSets + if len(enabledSets) == 0 { + return nil, nil, nil + } + + type candidate struct { + boost float64 + from string + } + + warnings := make([]Warning, 0) + selected := make(map[string]candidate) + + for _, name := range enabledSets { + set, ok := cfg.Vocab.Sets[name] + if !ok { + return nil, nil, fmt.Errorf("vocab.global references unknown set %q", name) + } + for _, phrase := range set.Phrases { + phrase = strings.TrimSpace(phrase) + if phrase == "" { + continue + } + if existing, exists := selected[phrase]; exists { + if set.Boost > existing.boost { + warnings = append(warnings, Warning{Message: fmt.Sprintf("phrase %q present in %q and %q; using higher boost %.2f", phrase, existing.from, name, set.Boost)}) + selected[phrase] = candidate{boost: set.Boost, from: name} + } + continue + } + selected[phrase] = candidate{boost: set.Boost, from: name} + } + } + + if len(selected) > cfg.Vocab.MaxPhrases { + return nil, nil, fmt.Errorf("vocabulary phrase count %d exceeds vocab.max_phrases=%d", len(selected), cfg.Vocab.MaxPhrases) + } + + phrases := make([]SpeechPhrase, 0, len(selected)) + for phrase, c := range selected { + phrases = append(phrases, SpeechPhrase{Phrase: phrase, Boost: float32(c.boost)}) + } + + sort.Slice(phrases, func(i, j int) bool { + if phrases[i].Phrase == phrases[j].Phrase { + return phrases[i].Boost < phrases[j].Boost + } + return phrases[i].Phrase < phrases[j].Phrase + }) + + return phrases, warnings, nil +} diff --git a/apps/sotto/internal/config/validate_test.go b/apps/sotto/internal/config/validate_test.go new file mode 100644 index 0000000..20e78e8 --- /dev/null +++ b/apps/sotto/internal/config/validate_test.go @@ -0,0 +1,66 @@ +package config + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestBuildSpeechPhrasesSortedAndHighestBoostWins(t *testing.T) { + cfg := Default() + cfg.Vocab.GlobalSets = []string{"core", "team"} + cfg.Vocab.Sets["core"] = VocabSet{Name: "core", Boost: 10, Phrases: []string{"beta", "alpha"}} + cfg.Vocab.Sets["team"] = VocabSet{Name: "team", Boost: 20, Phrases: []string{"alpha", "gamma"}} + + phrases, warnings, err := BuildSpeechPhrases(cfg) + require.NoError(t, err) + require.Len(t, warnings, 1) + require.Equal(t, []SpeechPhrase{ + {Phrase: "alpha", Boost: 20}, + {Phrase: "beta", Boost: 10}, + {Phrase: "gamma", Boost: 20}, + }, phrases) +} + +func TestValidateRejectsInvalidCoreFields(t *testing.T) { + tests := []struct { + name string + mutate func(*Config) + wantErr string + }{ + {name: "empty riva grpc", mutate: func(c *Config) { c.RivaGRPC = "" }, wantErr: "riva_grpc"}, + {name: "empty riva http", mutate: func(c *Config) { c.RivaHTTP = "" }, wantErr: "riva_http"}, + {name: "bad health path", mutate: func(c *Config) { c.RivaHealthPath = "v1/health" }, wantErr: "must start"}, + {name: "empty language", mutate: func(c *Config) { c.ASR.LanguageCode = "" }, wantErr: "language_code"}, + {name: "invalid indicator backend", mutate: func(c *Config) { c.Indicator.Backend = "unknown" }, wantErr: "indicator.backend"}, + {name: "missing desktop app name", mutate: func(c *Config) { + c.Indicator.Backend = "desktop" + c.Indicator.DesktopAppName = "" + }, wantErr: "indicator.desktop_app_name"}, + {name: "invalid indicator height", mutate: func(c *Config) { c.Indicator.Height = 0 }, wantErr: "indicator.height"}, + {name: "negative error timeout", mutate: func(c *Config) { c.Indicator.ErrorTimeoutMS = -1 }, wantErr: "error_timeout"}, + {name: "invalid max phrases", mutate: func(c *Config) { c.Vocab.MaxPhrases = 0 }, wantErr: "vocab.max_phrases"}, + {name: "empty clipboard argv", mutate: func(c *Config) { c.Clipboard.Argv = nil }, wantErr: "clipboard_cmd"}, + {name: "paste command raw but empty argv", mutate: func(c *Config) { + c.Paste.Enable = true + c.PasteCmd.Raw = "mycmd" + c.PasteCmd.Argv = nil + }, wantErr: "paste_cmd"}, + {name: "missing paste shortcut when using default paste", mutate: func(c *Config) { + c.Paste.Enable = true + c.PasteCmd = CommandConfig{} + c.Paste.Shortcut = "" + }, wantErr: "paste.shortcut"}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + cfg := Default() + tc.mutate(&cfg) + + _, err := Validate(cfg) + require.Error(t, err) + require.Contains(t, err.Error(), tc.wantErr) + }) + } +} diff --git a/apps/sotto/internal/doctor/doctor.go b/apps/sotto/internal/doctor/doctor.go new file mode 100644 index 0000000..e5eed49 --- /dev/null +++ b/apps/sotto/internal/doctor/doctor.go @@ -0,0 +1,155 @@ +// Package doctor runs runtime readiness diagnostics for config, tools, audio, and Riva. +package doctor + +import ( + "context" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "strings" + "time" + + "github.com/rbright/sotto/internal/audio" + "github.com/rbright/sotto/internal/config" +) + +// Check is one doctor assertion result. +type Check struct { + Name string + Pass bool + Message string +} + +// Report is the full doctor output contract. +type Report struct { + Checks []Check +} + +// OK returns true when all checks pass. +func (r Report) OK() bool { + for _, check := range r.Checks { + if !check.Pass { + return false + } + } + return true +} + +// String renders the report as user-facing text output. +func (r Report) String() string { + var b strings.Builder + for _, check := range r.Checks { + status := "OK" + if !check.Pass { + status = "FAIL" + } + b.WriteString(fmt.Sprintf("[%s] %s: %s\n", status, check.Name, check.Message)) + } + return strings.TrimSuffix(b.String(), "\n") +} + +// Run executes environment/config/runtime checks for a loaded config. +func Run(cfg config.Loaded) Report { + checks := []Check{} + + checks = append(checks, Check{ + Name: "config", + Pass: true, + Message: fmt.Sprintf("loaded %q", cfg.Path), + }) + + checks = append(checks, checkEnv("XDG_SESSION_TYPE", func(v string) bool { + return strings.EqualFold(strings.TrimSpace(v), "wayland") + }, "session type is wayland", "expected XDG_SESSION_TYPE=wayland")) + + checks = append(checks, checkEnv("HYPRLAND_INSTANCE_SIGNATURE", func(v string) bool { + return strings.TrimSpace(v) != "" + }, "Hyprland session detected", "HYPRLAND_INSTANCE_SIGNATURE is empty")) + + checks = append(checks, checkCommand(cfg.Config.Clipboard.Argv, "clipboard_cmd")) + + if cfg.Config.Paste.Enable { + if len(cfg.Config.PasteCmd.Argv) > 0 { + checks = append(checks, checkCommand(cfg.Config.PasteCmd.Argv, "paste_cmd")) + } else { + checks = append(checks, checkBinary("hyprctl", "default paste path requires hyprctl")) + } + } + + checks = append(checks, checkAudioSelection(cfg.Config)) + checks = append(checks, checkRivaReady(cfg.Config)) + + return Report{Checks: checks} +} + +// checkEnv validates an environment variable through a caller-supplied predicate. +func checkEnv(name string, predicate func(string) bool, okMsg, failMsg string) Check { + value := os.Getenv(name) + if predicate(value) { + return Check{Name: name, Pass: true, Message: okMsg} + } + return Check{Name: name, Pass: false, Message: failMsg} +} + +// checkCommand validates that argv contains a runnable command. +func checkCommand(argv []string, name string) Check { + if len(argv) == 0 { + return Check{Name: name, Pass: false, Message: "command is empty"} + } + return checkBinary(argv[0], fmt.Sprintf("%s command is available", name)) +} + +// checkBinary validates that a binary exists in PATH. +func checkBinary(bin string, okMsg string) Check { + path, err := exec.LookPath(bin) + if err != nil { + return Check{Name: bin, Pass: false, Message: fmt.Sprintf("binary not found in PATH: %s", bin)} + } + return Check{Name: bin, Pass: true, Message: fmt.Sprintf("found at %s (%s)", path, okMsg)} +} + +// checkAudioSelection runs live device selection to surface selection/fallback issues. +func checkAudioSelection(cfg config.Config) Check { + selection, err := audio.SelectDevice(context.Background(), cfg.Audio.Input, cfg.Audio.Fallback) + if err != nil { + return Check{Name: "audio.device", Pass: false, Message: err.Error()} + } + message := fmt.Sprintf("selected %q", selection.Device.ID) + if selection.Warning != "" { + message = message + " (" + selection.Warning + ")" + } + return Check{Name: "audio.device", Pass: true, Message: message} +} + +// checkRivaReady probes the configured Riva HTTP ready endpoint. +func checkRivaReady(cfg config.Config) Check { + base := strings.TrimSpace(cfg.RivaHTTP) + if base == "" { + return Check{Name: "riva.ready", Pass: false, Message: "riva_http is empty"} + } + if !strings.HasPrefix(base, "http://") && !strings.HasPrefix(base, "https://") { + base = "http://" + base + } + + url := strings.TrimRight(base, "/") + cfg.RivaHealthPath + client := http.Client{Timeout: 2 * time.Second} + resp, err := client.Get(url) + if err != nil { + return Check{Name: "riva.ready", Pass: false, Message: fmt.Sprintf("request failed: %v", err)} + } + defer resp.Body.Close() + + body, _ := io.ReadAll(io.LimitReader(resp.Body, 256)) + if resp.StatusCode < 200 || resp.StatusCode >= 300 { + return Check{Name: "riva.ready", Pass: false, Message: fmt.Sprintf("HTTP %d from %s", resp.StatusCode, url)} + } + + bodyText := strings.ToLower(strings.TrimSpace(string(body))) + if bodyText != "" && !strings.Contains(bodyText, "ready") { + return Check{Name: "riva.ready", Pass: true, Message: fmt.Sprintf("HTTP %d from %s", resp.StatusCode, url)} + } + + return Check{Name: "riva.ready", Pass: true, Message: fmt.Sprintf("ready at %s", url)} +} diff --git a/apps/sotto/internal/doctor/doctor_test.go b/apps/sotto/internal/doctor/doctor_test.go new file mode 100644 index 0000000..428eafb --- /dev/null +++ b/apps/sotto/internal/doctor/doctor_test.go @@ -0,0 +1,133 @@ +package doctor + +import ( + "net/http" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/rbright/sotto/internal/config" + "github.com/stretchr/testify/require" +) + +func TestReportOKAndString(t *testing.T) { + report := Report{Checks: []Check{ + {Name: "one", Pass: true, Message: "good"}, + {Name: "two", Pass: false, Message: "bad"}, + }} + + require.False(t, report.OK()) + text := report.String() + require.Contains(t, text, "[OK] one: good") + require.Contains(t, text, "[FAIL] two: bad") +} + +func TestCheckEnv(t *testing.T) { + t.Setenv("TEST_DOCTOR_ENV", "wayland") + + check := checkEnv( + "TEST_DOCTOR_ENV", + func(v string) bool { return strings.EqualFold(v, "wayland") }, + "looks good", + "unexpected", + ) + + require.True(t, check.Pass) + require.Equal(t, "looks good", check.Message) +} + +func TestCheckCommandEmpty(t *testing.T) { + check := checkCommand(nil, "clipboard_cmd") + require.False(t, check.Pass) + require.Contains(t, check.Message, "command is empty") +} + +func TestCheckBinaryFound(t *testing.T) { + check := checkBinary("sh", "shell available") + require.True(t, check.Pass) + require.Contains(t, check.Message, "shell available") +} + +func TestCheckBinaryMissing(t *testing.T) { + check := checkBinary("definitely-not-a-real-binary", "unused") + require.False(t, check.Pass) + require.Contains(t, check.Message, "binary not found") +} + +func TestCheckCommandUsesBinaryFromPath(t *testing.T) { + dir := t.TempDir() + scriptPath := filepath.Join(dir, "fake-bin") + require.NoError(t, os.WriteFile(scriptPath, []byte("#!/usr/bin/env bash\nexit 0\n"), 0o755)) + t.Setenv("PATH", dir+":"+os.Getenv("PATH")) + + check := checkCommand([]string{"fake-bin", "--arg"}, "clipboard_cmd") + require.True(t, check.Pass) + require.Contains(t, check.Message, "clipboard_cmd command is available") +} + +func TestCheckRivaReadySuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + require.Equal(t, "/v1/health/ready", r.URL.Path) + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ready")) + })) + t.Cleanup(server.Close) + + cfg := config.Default() + cfg.RivaHTTP = strings.TrimPrefix(server.URL, "http://") + cfg.RivaHealthPath = "/v1/health/ready" + + check := checkRivaReady(cfg) + require.True(t, check.Pass) + require.Contains(t, check.Message, "ready at") +} + +func TestCheckRivaReadyFailureStatusCode(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusServiceUnavailable) + })) + t.Cleanup(server.Close) + + cfg := config.Default() + cfg.RivaHTTP = strings.TrimPrefix(server.URL, "http://") + cfg.RivaHealthPath = "/v1/health/ready" + + check := checkRivaReady(cfg) + require.False(t, check.Pass) + require.Contains(t, check.Message, "HTTP 503") +} + +func TestCheckRivaReadyPassesOnHTTP200NonReadyBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("warming-up")) + })) + t.Cleanup(server.Close) + + cfg := config.Default() + cfg.RivaHTTP = strings.TrimPrefix(server.URL, "http://") + cfg.RivaHealthPath = "/v1/health/ready" + + check := checkRivaReady(cfg) + require.True(t, check.Pass) + require.Contains(t, check.Message, "HTTP 200") +} + +func TestCheckRivaReadyEmptyBaseURL(t *testing.T) { + cfg := config.Default() + cfg.RivaHTTP = "" + + check := checkRivaReady(cfg) + require.False(t, check.Pass) + require.Contains(t, check.Message, "riva_http is empty") +} + +func TestCheckAudioSelectionFailureWithInvalidPulseServer(t *testing.T) { + t.Setenv("PULSE_SERVER", "unix:/tmp/definitely-missing-pulse-server") + + check := checkAudioSelection(config.Default()) + require.False(t, check.Pass) + require.Contains(t, check.Name, "audio.device") +} diff --git a/apps/sotto/internal/fsm/fsm.go b/apps/sotto/internal/fsm/fsm.go new file mode 100644 index 0000000..27a2a07 --- /dev/null +++ b/apps/sotto/internal/fsm/fsm.go @@ -0,0 +1,73 @@ +// Package fsm contains the session lifecycle state machine. +package fsm + +import "fmt" + +// State is one lifecycle state for a dictation session. +type State string + +// Event is one transition trigger consumed by the state machine. +type Event string + +const ( + StateIdle State = "idle" + StateRecording State = "recording" + StateTranscribing State = "transcribing" + StateError State = "error" +) + +const ( + EventStart Event = "start" + EventStop Event = "stop" + EventCancel Event = "cancel" + EventTranscribed Event = "transcribed" + EventFail Event = "fail" + EventReset Event = "reset" +) + +// Transition validates and applies one state transition. +func Transition(current State, event Event) (State, error) { + if event == EventFail { + return StateError, nil + } + + switch current { + case StateIdle: + switch event { + case EventStart: + return StateRecording, nil + default: + return current, invalidTransition(current, event) + } + case StateRecording: + switch event { + case EventStop: + return StateTranscribing, nil + case EventCancel: + return StateIdle, nil + default: + return current, invalidTransition(current, event) + } + case StateTranscribing: + switch event { + case EventTranscribed: + return StateIdle, nil + default: + return current, invalidTransition(current, event) + } + case StateError: + switch event { + case EventReset: + return StateIdle, nil + default: + return current, invalidTransition(current, event) + } + default: + return current, fmt.Errorf("unknown state %q", current) + } +} + +// invalidTransition formats a stable error message used by tests and callers. +func invalidTransition(state State, event Event) error { + return fmt.Errorf("invalid transition: %s --(%s)--> ?", state, event) +} diff --git a/apps/sotto/internal/fsm/fsm_test.go b/apps/sotto/internal/fsm/fsm_test.go new file mode 100644 index 0000000..b800982 --- /dev/null +++ b/apps/sotto/internal/fsm/fsm_test.go @@ -0,0 +1,72 @@ +package fsm + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestTransitionHappyPath(t *testing.T) { + s := StateIdle + + next, err := Transition(s, EventStart) + require.NoError(t, err) + require.Equal(t, StateRecording, next) + + next, err = Transition(next, EventStop) + require.NoError(t, err) + require.Equal(t, StateTranscribing, next) + + next, err = Transition(next, EventTranscribed) + require.NoError(t, err) + require.Equal(t, StateIdle, next) +} + +func TestTransitionFailFromAnyStateGoesError(t *testing.T) { + states := []State{StateIdle, StateRecording, StateTranscribing, StateError} + for _, state := range states { + next, err := Transition(state, EventFail) + require.NoError(t, err) + require.Equal(t, StateError, next) + } +} + +func TestTransitionMatrixInvalidTransitions(t *testing.T) { + tests := []struct { + name string + state State + event Event + want State + wantErr bool + }{ + {name: "idle stop invalid", state: StateIdle, event: EventStop, want: StateIdle, wantErr: true}, + {name: "idle cancel invalid", state: StateIdle, event: EventCancel, want: StateIdle, wantErr: true}, + {name: "recording start invalid", state: StateRecording, event: EventStart, want: StateRecording, wantErr: true}, + {name: "recording transcribed invalid", state: StateRecording, event: EventTranscribed, want: StateRecording, wantErr: true}, + {name: "transcribing stop invalid", state: StateTranscribing, event: EventStop, want: StateTranscribing, wantErr: true}, + {name: "transcribing cancel invalid", state: StateTranscribing, event: EventCancel, want: StateTranscribing, wantErr: true}, + {name: "error start invalid", state: StateError, event: EventStart, want: StateError, wantErr: true}, + {name: "error stop invalid", state: StateError, event: EventStop, want: StateError, wantErr: true}, + {name: "error reset valid", state: StateError, event: EventReset, want: StateIdle, wantErr: false}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + next, err := Transition(tc.state, tc.event) + require.Equal(t, tc.want, next) + if tc.wantErr { + require.Error(t, err) + require.Contains(t, err.Error(), "invalid transition") + return + } + require.NoError(t, err) + }) + } +} + +func TestTransitionUnknownState(t *testing.T) { + next, err := Transition(State("mystery"), EventStart) + require.Error(t, err) + require.Contains(t, err.Error(), "unknown state") + require.Equal(t, State("mystery"), next) +} diff --git a/apps/sotto/internal/hypr/hypr.go b/apps/sotto/internal/hypr/hypr.go new file mode 100644 index 0000000..8b863d0 --- /dev/null +++ b/apps/sotto/internal/hypr/hypr.go @@ -0,0 +1,53 @@ +// Package hypr wraps Hyprland IPC commands and JSON query decoding. +package hypr + +import ( + "context" + "errors" + "fmt" + "os/exec" + "strings" +) + +// Controller abstracts submap control for adapters that need it. +type Controller interface { + SetSubmap(ctx context.Context, name string) error + ResetSubmap(ctx context.Context) error +} + +// CLIController issues submap commands through hyprctl. +type CLIController struct{} + +// SetSubmap sets an explicit Hyprland submap name. +func (CLIController) SetSubmap(ctx context.Context, name string) error { + name = strings.TrimSpace(name) + if name == "" { + return errors.New("submap name must not be empty") + } + return runHyprctl(ctx, "dispatch", "submap", name) +} + +// ResetSubmap resets back to the default Hyprland submap. +func (c CLIController) ResetSubmap(ctx context.Context) error { + return c.SetSubmap(ctx, "reset") +} + +// runHyprctl executes hyprctl and discards stdout on success. +func runHyprctl(ctx context.Context, args ...string) error { + _, err := runHyprctlOutput(ctx, args...) + return err +} + +// runHyprctlOutput executes hyprctl and returns combined output for diagnostics. +func runHyprctlOutput(ctx context.Context, args ...string) ([]byte, error) { + cmd := exec.CommandContext(ctx, "hyprctl", args...) + out, err := cmd.CombinedOutput() + if err != nil { + trimmed := strings.TrimSpace(string(out)) + if trimmed == "" { + return nil, fmt.Errorf("hyprctl %v failed: %w", args, err) + } + return nil, fmt.Errorf("hyprctl %v failed: %w (%s)", args, err, trimmed) + } + return out, nil +} diff --git a/apps/sotto/internal/hypr/hypr_test.go b/apps/sotto/internal/hypr/hypr_test.go new file mode 100644 index 0000000..18feac4 --- /dev/null +++ b/apps/sotto/internal/hypr/hypr_test.go @@ -0,0 +1,99 @@ +package hypr + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestQueryActiveWindowAndFocusedMonitor(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", argsFile) + installHyprctlStub(t, ` +if [[ "${1:-}" == "-j" && "${2:-}" == "activewindow" ]]; then + echo '{"address":" 0xabc ","class":" brave-browser ","initialClass":" Brave "}' + exit 0 +fi +if [[ "${1:-}" == "-j" && "${2:-}" == "monitors" ]]; then + echo '[{"name":"HDMI-A-1","focused":false},{"name":" DP-1 ","focused":true}]' + exit 0 +fi +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +`) + + window, err := QueryActiveWindow(context.Background()) + require.NoError(t, err) + require.Equal(t, "0xabc", window.Address) + require.Equal(t, "brave-browser", window.Class) + require.Equal(t, "Brave", window.InitialClass) + + monitor, err := QueryFocusedMonitor(context.Background()) + require.NoError(t, err) + require.Equal(t, "DP-1", monitor) +} + +func TestQueryActiveWindowRejectsEmptyAddress(t *testing.T) { + installHyprctlStub(t, ` +if [[ "${1:-}" == "-j" && "${2:-}" == "activewindow" ]]; then + echo '{"address":"","class":"brave"}' + exit 0 +fi +echo '[]' +`) + + _, err := QueryActiveWindow(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "empty address") +} + +func TestSendShortcutRequiresNonEmptyPayload(t *testing.T) { + err := SendShortcut(context.Background(), " ") + require.Error(t, err) + require.Contains(t, err.Error(), "non-empty payload") +} + +func TestNotifyAndDismissUseHyprctlDispatch(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", argsFile) + installHyprctlStub(t, ` +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +`) + + err := Notify(context.Background(), 3, 1200, "", "Speech recognition error") + require.NoError(t, err) + + err = DismissNotify(context.Background()) + require.NoError(t, err) + + data, err := os.ReadFile(argsFile) + require.NoError(t, err) + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + require.Len(t, lines, 2) + require.Equal(t, "--quiet dispatch notify 3 1200 rgb(89b4fa) Speech recognition error", lines[0]) + require.Equal(t, "--quiet dispatch dismissnotify", lines[1]) +} + +func TestSendShortcutReturnsCombinedOutputOnFailure(t *testing.T) { + installHyprctlStub(t, ` +echo 'boom from hyprctl' >&2 +exit 1 +`) + + err := SendShortcut(context.Background(), "CTRL,V,address:0xabc") + require.Error(t, err) + require.Contains(t, err.Error(), "boom from hyprctl") +} + +func installHyprctlStub(t *testing.T, body string) { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "hyprctl") + script := "#!/usr/bin/env bash\nset -euo pipefail\n" + body + "\n" + require.NoError(t, os.WriteFile(path, []byte(script), 0o755)) + t.Setenv("PATH", dir+":"+os.Getenv("PATH")) +} diff --git a/apps/sotto/internal/hypr/query.go b/apps/sotto/internal/hypr/query.go new file mode 100644 index 0000000..f400d8f --- /dev/null +++ b/apps/sotto/internal/hypr/query.go @@ -0,0 +1,103 @@ +package hypr + +import ( + "context" + "encoding/json" + "fmt" + "strconv" + "strings" +) + +// ActiveWindow contains the fields needed for paste dispatch targeting. +type ActiveWindow struct { + Address string `json:"address"` + Class string `json:"class"` + InitialClass string `json:"initialClass"` +} + +type monitor struct { + Name string `json:"name"` + Focused bool `json:"focused"` +} + +// QueryActiveWindow fetches and validates the active-window contract from hyprctl. +func QueryActiveWindow(ctx context.Context) (ActiveWindow, error) { + output, err := runHyprctlJSON(ctx, "activewindow") + if err != nil { + return ActiveWindow{}, err + } + + var window ActiveWindow + if err := json.Unmarshal(output, &window); err != nil { + return ActiveWindow{}, fmt.Errorf("decode hyprctl activewindow json: %w", err) + } + window.Address = strings.TrimSpace(window.Address) + window.Class = strings.TrimSpace(window.Class) + window.InitialClass = strings.TrimSpace(window.InitialClass) + if window.Address == "" { + return ActiveWindow{}, fmt.Errorf("hyprctl activewindow returned empty address") + } + return window, nil +} + +// QueryFocusedMonitor returns the focused monitor name (or the first monitor fallback). +func QueryFocusedMonitor(ctx context.Context) (string, error) { + output, err := runHyprctlJSON(ctx, "monitors") + if err != nil { + return "", err + } + + var monitors []monitor + if err := json.Unmarshal(output, &monitors); err != nil { + return "", fmt.Errorf("decode hyprctl monitors json: %w", err) + } + for _, mon := range monitors { + if mon.Focused { + return strings.TrimSpace(mon.Name), nil + } + } + if len(monitors) == 0 { + return "", fmt.Errorf("hyprctl monitors returned no outputs") + } + return strings.TrimSpace(monitors[0].Name), nil +} + +// SendShortcut sends a literal hyprctl sendshortcut payload. +func SendShortcut(ctx context.Context, shortcut string) error { + shortcut = strings.TrimSpace(shortcut) + if shortcut == "" { + return fmt.Errorf("sendshortcut requires a non-empty payload") + } + return runHyprctl(ctx, "--quiet", "dispatch", "sendshortcut", shortcut) +} + +// Notify sends a Hyprland notification payload. +func Notify(ctx context.Context, icon int, timeoutMS int, color string, text string) error { + if strings.TrimSpace(color) == "" { + color = "rgb(89b4fa)" + } + return runHyprctl( + ctx, + "--quiet", + "dispatch", + "notify", + strconv.Itoa(icon), + strconv.Itoa(timeoutMS), + color, + text, + ) +} + +// DismissNotify dismisses active Hyprland notifications. +func DismissNotify(ctx context.Context) error { + return runHyprctl(ctx, "--quiet", "dispatch", "dismissnotify") +} + +// runHyprctlJSON executes a JSON-returning hyprctl subcommand. +func runHyprctlJSON(ctx context.Context, target string) ([]byte, error) { + output, err := runHyprctlOutput(ctx, "-j", target) + if err != nil { + return nil, err + } + return output, nil +} diff --git a/apps/sotto/internal/indicator/desktop_notify.go b/apps/sotto/internal/indicator/desktop_notify.go new file mode 100644 index 0000000..bc7e41f --- /dev/null +++ b/apps/sotto/internal/indicator/desktop_notify.go @@ -0,0 +1,76 @@ +package indicator + +import ( + "context" + "fmt" + "os/exec" + "strconv" + "strings" +) + +// desktopNotify sends a freedesktop notification over DBus via busctl. +// It returns the notification ID assigned by the server. +func desktopNotify(ctx context.Context, appName string, replaceID uint32, summary string, timeoutMS int) (uint32, error) { + args := []string{ + "--user", + "call", + "org.freedesktop.Notifications", + "/org/freedesktop/Notifications", + "org.freedesktop.Notifications", + "Notify", + "susssasa{sv}i", + appName, + fmt.Sprintf("%d", replaceID), + "", + summary, + "", + "0", // actions array length + "0", // hints map length + fmt.Sprintf("%d", timeoutMS), + } + + out, err := exec.CommandContext(ctx, "busctl", args...).CombinedOutput() + if err != nil { + trimmed := strings.TrimSpace(string(out)) + if trimmed == "" { + return 0, fmt.Errorf("desktop notify failed: %w", err) + } + return 0, fmt.Errorf("desktop notify failed: %w (%s)", err, trimmed) + } + + fields := strings.Fields(strings.TrimSpace(string(out))) + if len(fields) < 2 || fields[0] != "u" { + return 0, fmt.Errorf("desktop notify invalid response: %q", strings.TrimSpace(string(out))) + } + + value, parseErr := strconv.ParseUint(fields[1], 10, 32) + if parseErr != nil { + return 0, fmt.Errorf("desktop notify parse id %q: %w", fields[1], parseErr) + } + return uint32(value), nil +} + +// desktopDismiss requests explicit close by notification ID. +func desktopDismiss(ctx context.Context, id uint32) error { + args := []string{ + "--user", + "call", + "org.freedesktop.Notifications", + "/org/freedesktop/Notifications", + "org.freedesktop.Notifications", + "CloseNotification", + "u", + fmt.Sprintf("%d", id), + } + + out, err := exec.CommandContext(ctx, "busctl", args...).CombinedOutput() + if err != nil { + trimmed := strings.TrimSpace(string(out)) + if trimmed == "" { + return fmt.Errorf("desktop dismiss failed: %w", err) + } + return fmt.Errorf("desktop dismiss failed: %w (%s)", err, trimmed) + } + + return nil +} diff --git a/apps/sotto/internal/indicator/indicator.go b/apps/sotto/internal/indicator/indicator.go new file mode 100644 index 0000000..3a71239 --- /dev/null +++ b/apps/sotto/internal/indicator/indicator.go @@ -0,0 +1,213 @@ +// Package indicator handles visual state notifications and audio cue playback. +package indicator + +import ( + "context" + "log/slog" + "strings" + "sync" + "time" + + "github.com/rbright/sotto/internal/config" + "github.com/rbright/sotto/internal/hypr" +) + +// Controller is the session-facing indicator contract. +type Controller interface { + ShowRecording(context.Context) + ShowTranscribing(context.Context) + ShowError(context.Context, string) + CueStop(context.Context) + CueComplete(context.Context) + CueCancel(context.Context) + Hide(context.Context) + FocusedMonitor() string +} + +// HyprNotify is the concrete indicator implementation used by runtime sessions. +// It can route notifications via Hyprland or desktop DBus based on config backend. +type HyprNotify struct { + cfg config.IndicatorConfig + logger *slog.Logger + + mu sync.Mutex + focusedMonitor string + desktopNotificationID uint32 + soundMu sync.Mutex +} + +// NewHyprNotify creates an indicator controller from config. +func NewHyprNotify(cfg config.IndicatorConfig, logger *slog.Logger) *HyprNotify { + return &HyprNotify{cfg: cfg, logger: logger} +} + +// ShowRecording signals recording start and emits the start cue. +func (h *HyprNotify) ShowRecording(ctx context.Context) { + h.playCue(cueStart) + if !h.cfg.Enable { + return + } + h.ensureFocusedMonitor(ctx) + h.run(ctx, func(ctx context.Context) error { + return h.notify(ctx, 1, 300000, "rgb(89b4fa)", h.cfg.TextRecording) + }) +} + +// ShowTranscribing signals the post-capture transcription state. +func (h *HyprNotify) ShowTranscribing(ctx context.Context) { + if !h.cfg.Enable { + return + } + h.run(ctx, func(ctx context.Context) error { + return h.notify(ctx, 1, 300000, "rgb(cba6f7)", h.cfg.TextProcessing) + }) +} + +// ShowError displays an error-state indicator message. +func (h *HyprNotify) ShowError(ctx context.Context, text string) { + if !h.cfg.Enable { + return + } + if text == "" { + text = h.cfg.TextError + } + timeout := h.cfg.ErrorTimeoutMS + if timeout <= 0 { + timeout = 1200 + } + h.run(ctx, func(ctx context.Context) error { + return h.notify(ctx, 3, timeout, "rgb(f38ba8)", text) + }) +} + +// CueStop emits the stop cue. +func (h *HyprNotify) CueStop(context.Context) { + h.playCue(cueStop) +} + +// CueComplete emits the successful-commit cue. +func (h *HyprNotify) CueComplete(context.Context) { + h.playCue(cueComplete) +} + +// CueCancel emits the cancel cue. +func (h *HyprNotify) CueCancel(context.Context) { + h.playCue(cueCancel) +} + +// Hide dismisses the active indicator surface. +func (h *HyprNotify) Hide(ctx context.Context) { + if !h.cfg.Enable { + return + } + h.run(ctx, h.dismiss) +} + +// FocusedMonitor returns the monitor captured when recording began. +func (h *HyprNotify) FocusedMonitor() string { + h.mu.Lock() + defer h.mu.Unlock() + return h.focusedMonitor +} + +// ensureFocusedMonitor resolves and caches the focused monitor once per session. +func (h *HyprNotify) ensureFocusedMonitor(ctx context.Context) { + h.mu.Lock() + alreadySet := h.focusedMonitor != "" + h.mu.Unlock() + if alreadySet { + return + } + + monitor, err := hypr.QueryFocusedMonitor(ctx) + if err != nil { + h.log("indicator focused monitor query failed", err) + return + } + + h.mu.Lock() + h.focusedMonitor = monitor + h.mu.Unlock() +} + +// notify dispatches indicator output through the configured backend. +func (h *HyprNotify) notify(ctx context.Context, icon int, timeoutMS int, color string, text string) error { + if strings.EqualFold(strings.TrimSpace(h.cfg.Backend), "desktop") { + return h.notifyDesktop(ctx, timeoutMS, text) + } + return hypr.Notify(ctx, icon, timeoutMS, color, text) +} + +// dismiss removes indicator output from the configured backend. +func (h *HyprNotify) dismiss(ctx context.Context) error { + if strings.EqualFold(strings.TrimSpace(h.cfg.Backend), "desktop") { + return h.dismissDesktop(ctx) + } + return hypr.DismissNotify(ctx) +} + +// notifyDesktop sends a replaceable desktop notification and stores its ID. +func (h *HyprNotify) notifyDesktop(ctx context.Context, timeoutMS int, text string) error { + h.mu.Lock() + replaceID := h.desktopNotificationID + h.mu.Unlock() + + appName := strings.TrimSpace(h.cfg.DesktopAppName) + if appName == "" { + appName = "sotto-indicator" + } + + id, err := desktopNotify(ctx, appName, replaceID, text, timeoutMS) + if err != nil { + return err + } + + h.mu.Lock() + h.desktopNotificationID = id + h.mu.Unlock() + return nil +} + +// dismissDesktop closes the current desktop notification ID when present. +func (h *HyprNotify) dismissDesktop(ctx context.Context) error { + h.mu.Lock() + id := h.desktopNotificationID + h.desktopNotificationID = 0 + h.mu.Unlock() + + if id == 0 { + return nil + } + return desktopDismiss(ctx, id) +} + +// run executes an indicator operation with a bounded timeout. +func (h *HyprNotify) run(ctx context.Context, fn func(context.Context) error) { + runCtx, cancel := context.WithTimeout(ctx, 400*time.Millisecond) + defer cancel() + if err := fn(runCtx); err != nil { + h.log("indicator dispatch failed", err) + } +} + +// playCue serializes cue playback and emits audio asynchronously. +func (h *HyprNotify) playCue(kind cueKind) { + if !h.cfg.SoundEnable { + return + } + go func() { + h.soundMu.Lock() + defer h.soundMu.Unlock() + if err := emitCue(kind, h.cfg); err != nil { + h.log("indicator audio cue failed", err) + } + }() +} + +// log emits debug-only indicator failures to the runtime logger. +func (h *HyprNotify) log(message string, err error) { + if h.logger == nil || err == nil { + return + } + h.logger.Debug(message, "error", err.Error()) +} diff --git a/apps/sotto/internal/indicator/indicator_test.go b/apps/sotto/internal/indicator/indicator_test.go new file mode 100644 index 0000000..9d86349 --- /dev/null +++ b/apps/sotto/internal/indicator/indicator_test.go @@ -0,0 +1,164 @@ +package indicator + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/rbright/sotto/internal/config" + "github.com/stretchr/testify/require" +) + +func TestHyprNotifyDispatchAndFocusedMonitorTracking(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", argsFile) + installHyprctlStub(t, ` +if [[ "${1:-}" == "-j" && "${2:-}" == "monitors" ]]; then + echo '[{"name":"DP-1","focused":true}]' + exit 0 +fi +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +`) + + cfg := config.Default().Indicator + cfg.SoundEnable = false + cfg.Enable = true + cfg.TextRecording = "Recording" + cfg.TextProcessing = "Transcribing" + cfg.TextError = "Speech error" + + notify := NewHyprNotify(cfg, nil) + notify.ShowRecording(context.Background()) + notify.ShowTranscribing(context.Background()) + notify.ShowError(context.Background(), "") + notify.Hide(context.Background()) + + require.Equal(t, "DP-1", notify.FocusedMonitor()) + + data, err := os.ReadFile(argsFile) + require.NoError(t, err) + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + require.Len(t, lines, 4) + require.Equal(t, "--quiet dispatch notify 1 300000 rgb(89b4fa) Recording", lines[0]) + require.Equal(t, "--quiet dispatch notify 1 300000 rgb(cba6f7) Transcribing", lines[1]) + require.Equal(t, "--quiet dispatch notify 3 1600 rgb(f38ba8) Speech error", lines[2]) + require.Equal(t, "--quiet dispatch dismissnotify", lines[3]) +} + +func TestHyprNotifyShowErrorUsesProvidedTextAndDefaultTimeout(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", argsFile) + installHyprctlStub(t, ` +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +`) + + cfg := config.Default().Indicator + cfg.SoundEnable = false + cfg.ErrorTimeoutMS = 0 // exercises fallback to 1200ms + + notify := NewHyprNotify(cfg, nil) + notify.ShowError(context.Background(), "custom error") + + data, err := os.ReadFile(argsFile) + require.NoError(t, err) + require.Equal(t, "--quiet dispatch notify 3 1200 rgb(f38ba8) custom error\n", string(data)) +} + +func TestDesktopIndicatorUsesBusctlNotifyAndDismiss(t *testing.T) { + busctlArgs := filepath.Join(t.TempDir(), "busctl-args.log") + t.Setenv("BUSCTL_ARGS_FILE", busctlArgs) + installBusctlStub(t) + + hyprArgs := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", hyprArgs) + installHyprctlStub(t, ` +if [[ "${1:-}" == "-j" && "${2:-}" == "monitors" ]]; then + echo '[{"name":"DP-1","focused":true}]' + exit 0 +fi +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +`) + + cfg := config.Default().Indicator + cfg.Enable = true + cfg.SoundEnable = false + cfg.Backend = "desktop" + cfg.DesktopAppName = "sotto-indicator" + + notify := NewHyprNotify(cfg, nil) + notify.ShowRecording(context.Background()) + notify.ShowTranscribing(context.Background()) + notify.Hide(context.Background()) + + data, err := os.ReadFile(busctlArgs) + require.NoError(t, err) + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + require.Len(t, lines, 3) + require.Contains(t, lines[0], "Notify susssasa{sv}i sotto-indicator 0") + require.Contains(t, lines[1], "Notify susssasa{sv}i sotto-indicator 42") + require.Contains(t, lines[2], "CloseNotification u 42") +} + +func TestHyprNotifyDisabledSkipsHyprctlDispatch(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", argsFile) + installHyprctlStub(t, ` +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +`) + + cfg := config.Default().Indicator + cfg.Enable = false + cfg.SoundEnable = false + + notify := NewHyprNotify(cfg, nil) + notify.ShowRecording(context.Background()) + notify.ShowTranscribing(context.Background()) + notify.ShowError(context.Background(), "ignored") + notify.Hide(context.Background()) + + _, err := os.Stat(argsFile) + require.Error(t, err) + require.True(t, os.IsNotExist(err)) +} + +func TestFocusedMonitorStaysEmptyWhenQueryFails(t *testing.T) { + installHyprctlStub(t, ` +exit 1 +`) + + cfg := config.Default().Indicator + cfg.Enable = true + cfg.SoundEnable = false + + notify := NewHyprNotify(cfg, nil) + notify.ShowRecording(context.Background()) + require.Empty(t, notify.FocusedMonitor()) +} + +func installHyprctlStub(t *testing.T, body string) { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "hyprctl") + script := "#!/usr/bin/env bash\nset -euo pipefail\n" + body + "\n" + require.NoError(t, os.WriteFile(path, []byte(script), 0o755)) + t.Setenv("PATH", dir+":"+os.Getenv("PATH")) +} + +func installBusctlStub(t *testing.T) { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "busctl") + script := `#!/usr/bin/env bash +set -euo pipefail +printf '%s\n' "$*" >> "${BUSCTL_ARGS_FILE}" +if [[ "$*" == *" Notify "* ]]; then + echo "u 42" +fi +` + require.NoError(t, os.WriteFile(path, []byte(script), 0o755)) + t.Setenv("PATH", dir+":"+os.Getenv("PATH")) +} diff --git a/apps/sotto/internal/indicator/sound.go b/apps/sotto/internal/indicator/sound.go new file mode 100644 index 0000000..8a0141a --- /dev/null +++ b/apps/sotto/internal/indicator/sound.go @@ -0,0 +1,257 @@ +package indicator + +import ( + "context" + "fmt" + "math" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/jfreymuth/pulse" + "github.com/rbright/sotto/internal/config" +) + +// cueKind identifies each cue event used by the session lifecycle. +type cueKind int + +const ( + cueStart cueKind = iota + 1 + cueStop + cueComplete + cueCancel +) + +const cueSampleRate = 16000 + +// toneSpec describes one synthesized cue tone segment. +type toneSpec struct { + frequencyHz float64 + duration time.Duration + volume float64 +} + +var ( + startCuePCM = synthesizeCue([]toneSpec{ + {frequencyHz: 880, duration: 70 * time.Millisecond, volume: 0.18}, + {frequencyHz: 1175, duration: 70 * time.Millisecond, volume: 0.18}, + }) + stopCuePCM = synthesizeCue([]toneSpec{ + {frequencyHz: 620, duration: 120 * time.Millisecond, volume: 0.18}, + }) + completeCuePCM = synthesizeCue([]toneSpec{ + {frequencyHz: 740, duration: 65 * time.Millisecond, volume: 0.18}, + {frequencyHz: 988, duration: 90 * time.Millisecond, volume: 0.18}, + }) + cancelCuePCM = synthesizeCue([]toneSpec{ + {frequencyHz: 480, duration: 75 * time.Millisecond, volume: 0.18}, + {frequencyHz: 360, duration: 90 * time.Millisecond, volume: 0.18}, + }) +) + +// emitCue plays a configured WAV file when present, otherwise falls back to synthesis. +func emitCue(kind cueKind, cfg config.IndicatorConfig) error { + if path := cuePath(kind, cfg); path != "" { + if err := playCueFile(path); err == nil { + return nil + } + } + + samples := cueSamples(kind) + if len(samples) == 0 { + return nil + } + + return playSynthCue(samples) +} + +// cuePath resolves the configured WAV path for one cue kind. +func cuePath(kind cueKind, cfg config.IndicatorConfig) string { + var raw string + switch kind { + case cueStart: + raw = cfg.SoundStartFile + case cueStop: + raw = cfg.SoundStopFile + case cueComplete: + raw = cfg.SoundCompleteFile + case cueCancel: + raw = cfg.SoundCancelFile + default: + return "" + } + return expandUserPath(raw) +} + +// expandUserPath expands `~` prefixes for user-provided cue file paths. +func expandUserPath(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + if raw == "~" { + home, err := os.UserHomeDir() + if err != nil { + return raw + } + return home + } + if !strings.HasPrefix(raw, "~/") { + return raw + } + home, err := os.UserHomeDir() + if err != nil { + return raw + } + return filepath.Join(home, strings.TrimPrefix(raw, "~/")) +} + +// playCueFile plays a configured WAV file through pw-play. +func playCueFile(path string) error { + if _, err := os.Stat(path); err != nil { + return fmt.Errorf("stat cue file %q: %w", path, err) + } + + ctx, cancel := context.WithTimeout(context.Background(), 4*time.Second) + defer cancel() + + cmd := exec.CommandContext(ctx, "pw-play", "--media-role", "Notification", path) + if err := cmd.Run(); err != nil { + return fmt.Errorf("play cue file %q: %w", path, err) + } + return nil +} + +// playSynthCue streams synthesized PCM samples through Pulse playback. +func playSynthCue(samples []int16) error { + client, err := pulse.NewClient( + pulse.ClientApplicationName("sotto"), + pulse.ClientApplicationIconName("audio-input-microphone"), + ) + if err != nil { + return fmt.Errorf("connect pulse server: %w", err) + } + defer client.Close() + + cursor := 0 + reader := pulse.Int16Reader(func(buf []int16) (int, error) { + if cursor >= len(samples) { + return 0, pulse.EndOfData + } + + n := copy(buf, samples[cursor:]) + cursor += n + if cursor >= len(samples) { + return n, pulse.EndOfData + } + return n, nil + }) + + stream, err := client.NewPlayback( + reader, + pulse.PlaybackMono, + pulse.PlaybackSampleRate(cueSampleRate), + pulse.PlaybackLatency(0.02), + pulse.PlaybackMediaName("sotto indicator cue"), + ) + if err != nil { + return fmt.Errorf("create pulse playback stream: %w", err) + } + defer stream.Close() + + stream.Start() + stream.Drain() + if err := stream.Error(); err != nil { + return fmt.Errorf("play cue stream: %w", err) + } + + return nil +} + +// cueSamples returns the synthesized PCM table for one cue kind. +func cueSamples(kind cueKind) []int16 { + switch kind { + case cueStart: + return startCuePCM + case cueStop: + return stopCuePCM + case cueComplete: + return completeCuePCM + case cueCancel: + return cancelCuePCM + default: + return nil + } +} + +// synthesizeCue concatenates one or more tone segments with short silence gaps. +func synthesizeCue(parts []toneSpec) []int16 { + if len(parts) == 0 { + return nil + } + gapSamples := samplesForDuration(22 * time.Millisecond) + total := 0 + for i, part := range parts { + total += samplesForDuration(part.duration) + if i < len(parts)-1 { + total += gapSamples + } + } + + pcm := make([]int16, 0, total) + for i, part := range parts { + pcm = append(pcm, synthesizeTone(part)...) + if i < len(parts)-1 && gapSamples > 0 { + pcm = append(pcm, make([]int16, gapSamples)...) + } + } + + return pcm +} + +// synthesizeTone creates one windowed sine-wave segment. +func synthesizeTone(spec toneSpec) []int16 { + n := samplesForDuration(spec.duration) + if n <= 0 || spec.frequencyHz <= 0 || spec.volume <= 0 { + return nil + } + + attackRelease := n / 10 + maxRamp := cueSampleRate / 200 // 5ms + if attackRelease > maxRamp { + attackRelease = maxRamp + } + if attackRelease < 1 { + attackRelease = 1 + } + + pcm := make([]int16, n) + for i := 0; i < n; i++ { + envelope := 1.0 + if i < attackRelease { + envelope = float64(i) / float64(attackRelease) + } + releaseIndex := n - i - 1 + if releaseIndex < attackRelease { + release := float64(releaseIndex) / float64(attackRelease) + if release < envelope { + envelope = release + } + } + t := float64(i) / cueSampleRate + sample := math.Sin(2 * math.Pi * spec.frequencyHz * t) + pcm[i] = int16(math.Round(sample * spec.volume * envelope * 32767)) + } + + return pcm +} + +// samplesForDuration converts a time duration into cue sample count. +func samplesForDuration(d time.Duration) int { + if d <= 0 { + return 0 + } + return int(math.Round(d.Seconds() * cueSampleRate)) +} diff --git a/apps/sotto/internal/indicator/sound_test.go b/apps/sotto/internal/indicator/sound_test.go new file mode 100644 index 0000000..6ac39a4 --- /dev/null +++ b/apps/sotto/internal/indicator/sound_test.go @@ -0,0 +1,60 @@ +package indicator + +import ( + "path/filepath" + "testing" + "time" + + "github.com/rbright/sotto/internal/config" + "github.com/stretchr/testify/require" +) + +func TestCueSamplesPresent(t *testing.T) { + require.NotEmpty(t, cueSamples(cueStart)) + require.NotEmpty(t, cueSamples(cueStop)) + require.NotEmpty(t, cueSamples(cueComplete)) + require.NotEmpty(t, cueSamples(cueCancel)) +} + +func TestSynthesizeToneDuration(t *testing.T) { + got := synthesizeTone(toneSpec{frequencyHz: 440, duration: 100 * time.Millisecond, volume: 0.2}) + want := samplesForDuration(100 * time.Millisecond) + require.Len(t, got, want) +} + +func TestSynthesizeToneInvalidSpecReturnsEmpty(t *testing.T) { + require.Empty(t, synthesizeTone(toneSpec{frequencyHz: 0, duration: 100 * time.Millisecond, volume: 0.2})) + require.Empty(t, synthesizeTone(toneSpec{frequencyHz: 440, duration: 0, volume: 0.2})) + require.Empty(t, synthesizeTone(toneSpec{frequencyHz: 440, duration: 100 * time.Millisecond, volume: 0})) +} + +func TestCuePathMapping(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + cfg := config.IndicatorConfig{ + SoundStartFile: "~/start.wav", + SoundStopFile: "/tmp/stop.wav", + SoundCompleteFile: "/tmp/complete.wav", + SoundCancelFile: "/tmp/cancel.wav", + } + + require.Equal(t, filepath.Join(home, "start.wav"), cuePath(cueStart, cfg)) + require.Equal(t, "/tmp/stop.wav", cuePath(cueStop, cfg)) + require.Equal(t, "/tmp/complete.wav", cuePath(cueComplete, cfg)) + require.Equal(t, "/tmp/cancel.wav", cuePath(cueCancel, cfg)) +} + +func TestExpandUserPath(t *testing.T) { + home := t.TempDir() + t.Setenv("HOME", home) + + require.Equal(t, home, expandUserPath("~")) + require.Equal(t, filepath.Join(home, "Downloads", "sound.wav"), expandUserPath("~/Downloads/sound.wav")) + require.Equal(t, "/tmp/sound.wav", expandUserPath("/tmp/sound.wav")) + require.Empty(t, expandUserPath(" ")) +} + +func TestSamplesForDuration(t *testing.T) { + require.Equal(t, 0, samplesForDuration(0)) + require.Greater(t, samplesForDuration(25*time.Millisecond), 0) +} diff --git a/apps/sotto/internal/ipc/client.go b/apps/sotto/internal/ipc/client.go new file mode 100644 index 0000000..a5ce890 --- /dev/null +++ b/apps/sotto/internal/ipc/client.go @@ -0,0 +1,74 @@ +package ipc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "net" + "os" + "syscall" + "time" +) + +// Send opens a unix-socket request/response roundtrip with a deadline. +func Send(ctx context.Context, path string, req Request, timeout time.Duration) (Response, error) { + dialer := net.Dialer{Timeout: timeout} + conn, err := dialer.DialContext(ctx, "unix", path) + if err != nil { + return Response{}, err + } + defer conn.Close() + + deadline := time.Now().Add(timeout) + if err := conn.SetDeadline(deadline); err != nil { + return Response{}, fmt.Errorf("set deadline: %w", err) + } + + enc := json.NewEncoder(conn) + if err := enc.Encode(req); err != nil { + return Response{}, fmt.Errorf("encode request: %w", err) + } + + reader := bufio.NewReader(conn) + line, err := reader.ReadBytes('\n') + if err != nil { + return Response{}, fmt.Errorf("read response: %w", err) + } + + var resp Response + if err := json.Unmarshal(line, &resp); err != nil { + return Response{}, fmt.Errorf("decode response: %w", err) + } + + return resp, nil +} + +// Probe checks whether a responsive owner is currently listening on path. +func Probe(ctx context.Context, path string, timeout time.Duration) (bool, error) { + _, err := Send(ctx, path, Request{Command: "status"}, timeout) + if err == nil { + return true, nil + } + if isSocketMissing(err) || isConnectionRefused(err) { + return false, nil + } + return false, fmt.Errorf("probe socket: %w", err) +} + +// isSocketMissing reports absent-socket failures. +func isSocketMissing(err error) bool { + if err == nil { + return false + } + return errors.Is(err, os.ErrNotExist) +} + +// isConnectionRefused reports no-listener failures. +func isConnectionRefused(err error) bool { + if err == nil { + return false + } + return errors.Is(err, syscall.ECONNREFUSED) +} diff --git a/apps/sotto/internal/ipc/client_server_test.go b/apps/sotto/internal/ipc/client_server_test.go new file mode 100644 index 0000000..4450d1d --- /dev/null +++ b/apps/sotto/internal/ipc/client_server_test.go @@ -0,0 +1,156 @@ +package ipc + +import ( + "bufio" + "context" + "encoding/json" + "net" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestSendRoundTrip(t *testing.T) { + runtimeDir := t.TempDir() + socketPath := filepath.Join(runtimeDir, "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serveDone := make(chan error, 1) + go func() { + serveDone <- Serve(ctx, listener, HandlerFunc(func(_ context.Context, req Request) Response { + require.Equal(t, "status", req.Command) + return Response{OK: true, State: "recording", Message: "ok"} + })) + }() + + resp, err := Send(context.Background(), socketPath, Request{Command: "status"}, 200*time.Millisecond) + require.NoError(t, err) + require.True(t, resp.OK) + require.Equal(t, "recording", resp.State) + require.Equal(t, "ok", resp.Message) + + cancel() + require.NoError(t, <-serveDone) +} + +func TestSendDecodeResponseError(t *testing.T) { + runtimeDir := t.TempDir() + socketPath := filepath.Join(runtimeDir, "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + t.Cleanup(func() { _ = listener.Close() }) + + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer conn.Close() + + reader := bufio.NewReader(conn) + _, _ = reader.ReadBytes('\n') + _, _ = conn.Write([]byte("not-json\n")) + }() + + _, err = Send(context.Background(), socketPath, Request{Command: "status"}, 200*time.Millisecond) + require.Error(t, err) + require.Contains(t, err.Error(), "decode response") +} + +func TestSendReadResponseError(t *testing.T) { + runtimeDir := t.TempDir() + socketPath := filepath.Join(runtimeDir, "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + t.Cleanup(func() { _ = listener.Close() }) + + go func() { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + defer conn.Close() + + reader := bufio.NewReader(conn) + _, _ = reader.ReadBytes('\n') + }() + + _, err = Send(context.Background(), socketPath, Request{Command: "status"}, 200*time.Millisecond) + require.Error(t, err) + require.Contains(t, err.Error(), "read response") +} + +func TestServeDecodeRequestErrorResponse(t *testing.T) { + runtimeDir := t.TempDir() + socketPath := filepath.Join(runtimeDir, "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + serveDone := make(chan error, 1) + go func() { + serveDone <- Serve(ctx, listener, HandlerFunc(func(_ context.Context, _ Request) Response { + return Response{OK: true} + })) + }() + + conn, err := net.Dial("unix", socketPath) + require.NoError(t, err) + defer conn.Close() + + _, err = conn.Write([]byte("not-json\n")) + require.NoError(t, err) + + line, err := bufio.NewReader(conn).ReadBytes('\n') + require.NoError(t, err) + + var resp Response + require.NoError(t, json.Unmarshal(line, &resp)) + require.False(t, resp.OK) + require.Contains(t, resp.Error, "decode request") + + cancel() + require.NoError(t, <-serveDone) +} + +func TestProbe(t *testing.T) { + runtimeDir := t.TempDir() + socketPath := filepath.Join(runtimeDir, "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + serveDone := make(chan error, 1) + go func() { + serveDone <- Serve(ctx, listener, HandlerFunc(func(_ context.Context, req Request) Response { + if req.Command == "status" { + return Response{OK: true, State: "idle"} + } + return Response{OK: false, Error: "bad"} + })) + }() + + alive, probeErr := Probe(context.Background(), socketPath, 200*time.Millisecond) + require.NoError(t, probeErr) + require.True(t, alive) + + cancel() + require.NoError(t, <-serveDone) + + alive, probeErr = Probe(context.Background(), socketPath, 100*time.Millisecond) + require.NoError(t, probeErr) + require.False(t, alive) +} diff --git a/apps/sotto/internal/ipc/protocol.go b/apps/sotto/internal/ipc/protocol.go new file mode 100644 index 0000000..fcfe326 --- /dev/null +++ b/apps/sotto/internal/ipc/protocol.go @@ -0,0 +1,15 @@ +// Package ipc provides single-instance unix-socket protocol and server/client helpers. +package ipc + +// Request is one command sent over the local unix-domain socket. +type Request struct { + Command string `json:"command"` +} + +// Response is the normalized command outcome returned by the owner session. +type Response struct { + OK bool `json:"ok"` + State string `json:"state,omitempty"` + Message string `json:"message,omitempty"` + Error string `json:"error,omitempty"` +} diff --git a/apps/sotto/internal/ipc/server.go b/apps/sotto/internal/ipc/server.go new file mode 100644 index 0000000..dfe4011 --- /dev/null +++ b/apps/sotto/internal/ipc/server.go @@ -0,0 +1,66 @@ +package ipc + +import ( + "bufio" + "context" + "encoding/json" + "errors" + "fmt" + "net" + "sync" +) + +// Handler processes one IPC command request. +type Handler interface { + Handle(context.Context, Request) Response +} + +// HandlerFunc adapts a function to the Handler interface. +type HandlerFunc func(context.Context, Request) Response + +func (f HandlerFunc) Handle(ctx context.Context, req Request) Response { + return f(ctx, req) +} + +// Serve accepts unix-socket clients until context cancellation or listener close. +func Serve(ctx context.Context, listener net.Listener, handler Handler) error { + var wg sync.WaitGroup + + go func() { + <-ctx.Done() + _ = listener.Close() + }() + + for { + conn, err := listener.Accept() + if err != nil { + if errors.Is(err, net.ErrClosed) || ctx.Err() != nil { + wg.Wait() + return nil + } + return fmt.Errorf("accept IPC connection: %w", err) + } + + wg.Add(1) + go func(c net.Conn) { + defer wg.Done() + defer c.Close() + + reader := bufio.NewReader(c) + line, err := reader.ReadBytes('\n') + if err != nil { + _ = json.NewEncoder(c).Encode(Response{OK: false, Error: fmt.Sprintf("read request: %v", err)}) + return + } + + var req Request + if err := json.Unmarshal(line, &req); err != nil { + _ = json.NewEncoder(c).Encode(Response{OK: false, Error: fmt.Sprintf("decode request: %v", err)}) + return + } + + resp := handler.Handle(ctx, req) + _ = json.NewEncoder(c).Encode(resp) + }(conn) + } +} diff --git a/apps/sotto/internal/ipc/socket.go b/apps/sotto/internal/ipc/socket.go new file mode 100644 index 0000000..fecf243 --- /dev/null +++ b/apps/sotto/internal/ipc/socket.go @@ -0,0 +1,83 @@ +package ipc + +import ( + "context" + "errors" + "fmt" + "net" + "os" + "path/filepath" + "strings" + "time" +) + +// ErrAlreadyRunning indicates a responsive owner already holds the runtime socket. +var ErrAlreadyRunning = errors.New("sotto session already running") + +// RuntimeSocketPath returns the owner socket path derived from XDG_RUNTIME_DIR. +func RuntimeSocketPath() (string, error) { + runtimeDir := strings.TrimSpace(os.Getenv("XDG_RUNTIME_DIR")) + if runtimeDir == "" { + return "", errors.New("XDG_RUNTIME_DIR is not set") + } + return filepath.Join(runtimeDir, "sotto.sock"), nil +} + +// Acquire attempts to become the owner listener, cleaning stale sockets when safe. +func Acquire( + ctx context.Context, + path string, + probeTimeout time.Duration, + retries int, + rescue func(context.Context) error, +) (net.Listener, error) { + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return nil, fmt.Errorf("ensure runtime socket dir: %w", err) + } + + for attempt := 0; attempt <= retries; attempt++ { + listener, err := net.Listen("unix", path) + if err == nil { + _ = os.Chmod(path, 0o600) + return listener, nil + } + + if !isAddrInUse(err) { + return nil, fmt.Errorf("listen unix %s: %w", path, err) + } + + alive, probeErr := Probe(ctx, path, probeTimeout) + if alive { + return nil, ErrAlreadyRunning + } + if probeErr != nil { + return nil, fmt.Errorf("probe existing socket %s: %w", path, probeErr) + } + + if removeErr := os.Remove(path); removeErr != nil && !errors.Is(removeErr, os.ErrNotExist) { + return nil, fmt.Errorf("remove stale socket %s: %w", path, removeErr) + } + + if rescue != nil { + _ = rescue(ctx) + } + + if attempt < retries { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(time.Duration(25*(attempt+1)) * time.Millisecond): + } + } + } + + return nil, fmt.Errorf("failed to acquire socket %s after %d retries", path, retries) +} + +// isAddrInUse identifies listener errors caused by an existing socket path. +func isAddrInUse(err error) bool { + if err == nil { + return false + } + return strings.Contains(err.Error(), "address already in use") +} diff --git a/apps/sotto/internal/ipc/socket_test.go b/apps/sotto/internal/ipc/socket_test.go new file mode 100644 index 0000000..92a8f57 --- /dev/null +++ b/apps/sotto/internal/ipc/socket_test.go @@ -0,0 +1,112 @@ +package ipc + +import ( + "context" + "errors" + "net" + "os" + "path/filepath" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestAcquireRecoversStaleSocket(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + socketPath := filepath.Join(dir, "sotto.sock") + if err := os.WriteFile(socketPath, []byte("stale"), 0o600); err != nil { + t.Fatalf("WriteFile() error = %v", err) + } + + rescueCalls := 0 + listener, err := Acquire(context.Background(), socketPath, 50*time.Millisecond, 2, func(context.Context) error { + rescueCalls++ + return nil + }) + if err != nil { + t.Fatalf("Acquire() error = %v", err) + } + defer listener.Close() + + if rescueCalls == 0 { + t.Fatalf("expected stale-socket rescue to run") + } +} + +func TestAcquireReturnsAlreadyRunningWhenSocketResponsive(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + socketPath := filepath.Join(dir, "sotto.sock") + listener, err := net.Listen("unix", socketPath) + if err != nil { + t.Fatalf("net.Listen() error = %v", err) + } + defer listener.Close() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serverDone := make(chan error, 1) + go func() { + serverDone <- Serve(ctx, listener, HandlerFunc(func(_ context.Context, _ Request) Response { + return Response{OK: true, State: "recording"} + })) + }() + + _, err = Acquire(context.Background(), socketPath, 80*time.Millisecond, 1, nil) + if !errors.Is(err, ErrAlreadyRunning) { + t.Fatalf("Acquire() error = %v, want ErrAlreadyRunning", err) + } + + cancel() + if serveErr := <-serverDone; serveErr != nil { + t.Fatalf("Serve() error = %v", serveErr) + } +} + +func TestAcquireDoesNotUnlinkWhenProbeInconclusive(t *testing.T) { + t.Parallel() + + dir := t.TempDir() + socketPath := filepath.Join(dir, "sotto.sock") + + listener, err := net.Listen("unix", socketPath) + require.NoError(t, err) + + acceptDone := make(chan struct{}) + go func() { + defer close(acceptDone) + for { + conn, acceptErr := listener.Accept() + if acceptErr != nil { + return + } + go func(c net.Conn) { + defer c.Close() + time.Sleep(250 * time.Millisecond) + }(conn) + } + }() + + _, err = Acquire(context.Background(), socketPath, 30*time.Millisecond, 0, nil) + require.Error(t, err) + require.NotErrorIs(t, err, ErrAlreadyRunning) + require.Contains(t, err.Error(), "probe existing socket") + + _, statErr := os.Stat(socketPath) + require.NoError(t, statErr) + require.NoError(t, listener.Close()) + <-acceptDone +} + +func TestRuntimeSocketPathRequiresXDG(t *testing.T) { + t.Setenv("XDG_RUNTIME_DIR", "") + _, err := RuntimeSocketPath() + if err == nil { + t.Fatal("expected error") + } +} diff --git a/apps/sotto/internal/logging/logger.go b/apps/sotto/internal/logging/logger.go new file mode 100644 index 0000000..6327e88 --- /dev/null +++ b/apps/sotto/internal/logging/logger.go @@ -0,0 +1,57 @@ +// Package logging configures runtime JSONL logging output. +package logging + +import ( + "io" + "log/slog" + "os" + "path/filepath" + "strings" +) + +// Runtime bundles the configured logger and its open file handle lifecycle. +type Runtime struct { + Logger *slog.Logger + Path string + closer io.Closer +} + +// Close flushes and closes the logger output sink. +func (r Runtime) Close() error { + if r.closer == nil { + return nil + } + return r.closer.Close() +} + +// New builds a JSONL logger rooted at the resolved state path. +func New() (Runtime, error) { + path, err := resolveLogPath() + if err != nil { + return Runtime{}, err + } + if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil { + return Runtime{}, err + } + + f, err := os.OpenFile(path, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0o600) + if err != nil { + return Runtime{}, err + } + + h := slog.NewJSONHandler(f, &slog.HandlerOptions{Level: slog.LevelInfo}) + logger := slog.New(h) + return Runtime{Logger: logger, Path: path, closer: f}, nil +} + +// resolveLogPath selects XDG_STATE_HOME when available, otherwise ~/.local/state. +func resolveLogPath() (string, error) { + if xdg := strings.TrimSpace(os.Getenv("XDG_STATE_HOME")); xdg != "" { + return filepath.Join(xdg, "sotto", "log.jsonl"), nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", err + } + return filepath.Join(home, ".local", "state", "sotto", "log.jsonl"), nil +} diff --git a/apps/sotto/internal/logging/logger_test.go b/apps/sotto/internal/logging/logger_test.go new file mode 100644 index 0000000..d37771f --- /dev/null +++ b/apps/sotto/internal/logging/logger_test.go @@ -0,0 +1,48 @@ +package logging + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestResolveLogPathUsesXDGStateHome(t *testing.T) { + xdgStateHome := t.TempDir() + t.Setenv("XDG_STATE_HOME", xdgStateHome) + t.Setenv("HOME", t.TempDir()) + + path, err := resolveLogPath() + require.NoError(t, err) + require.Equal(t, filepath.Join(xdgStateHome, "sotto", "log.jsonl"), path) +} + +func TestResolveLogPathFallsBackToHome(t *testing.T) { + home := t.TempDir() + t.Setenv("XDG_STATE_HOME", "") + t.Setenv("HOME", home) + + path, err := resolveLogPath() + require.NoError(t, err) + require.Equal(t, filepath.Join(home, ".local", "state", "sotto", "log.jsonl"), path) +} + +func TestNewCreatesWritableJSONLogFile(t *testing.T) { + t.Setenv("XDG_STATE_HOME", t.TempDir()) + + runtime, err := New() + require.NoError(t, err) + + runtime.Logger.Info("unit-test-log", "component", "logging") + require.NoError(t, runtime.Close()) + + contents, err := os.ReadFile(runtime.Path) + require.NoError(t, err) + require.Contains(t, string(contents), `"msg":"unit-test-log"`) + require.Contains(t, string(contents), `"component":"logging"`) + + stat, err := os.Stat(runtime.Path) + require.NoError(t, err) + require.Equal(t, os.FileMode(0o600), stat.Mode().Perm()) +} diff --git a/apps/sotto/internal/output/clipboard.go b/apps/sotto/internal/output/clipboard.go new file mode 100644 index 0000000..cbdd58a --- /dev/null +++ b/apps/sotto/internal/output/clipboard.go @@ -0,0 +1,96 @@ +// Package output applies transcript commit side effects (clipboard and paste). +package output + +import ( + "context" + "fmt" + "log/slog" + "os/exec" + "time" + + "github.com/rbright/sotto/internal/config" +) + +// Committer applies transcript output side effects (clipboard + optional paste). +type Committer struct { + config config.Config + logger *slog.Logger +} + +// NewCommitter constructs a transcript committer from runtime config. +func NewCommitter(cfg config.Config, logger *slog.Logger) *Committer { + return &Committer{config: cfg, logger: logger} +} + +// Commit writes transcript text to clipboard and optionally dispatches paste. +func (c *Committer) Commit(ctx context.Context, transcript string) error { + if transcript == "" { + return nil + } + + clipboardCtx, clipboardCancel := context.WithTimeout(ctx, 2*time.Second) + defer clipboardCancel() + if err := runCommandWithInput(clipboardCtx, c.config.Clipboard.Argv, transcript); err != nil { + return fmt.Errorf("set clipboard: %w", err) + } + + if !c.config.Paste.Enable { + return nil + } + + if len(c.config.PasteCmd.Argv) > 0 { + pasteCtx, pasteCancel := context.WithTimeout(ctx, 2*time.Second) + defer pasteCancel() + if err := runCommandWithInput(pasteCtx, c.config.PasteCmd.Argv, ""); err != nil { + c.logPasteFailure(err) + } + return nil + } + + pasteCtx, pasteCancel := context.WithTimeout(ctx, 1200*time.Millisecond) + defer pasteCancel() + if err := defaultPaste(pasteCtx, c.config.Paste.Shortcut); err != nil { + c.logPasteFailure(err) + } + return nil +} + +// runCommandWithInput executes argv and optionally writes input to stdin. +func runCommandWithInput(ctx context.Context, argv []string, input string) error { + if len(argv) == 0 { + return fmt.Errorf("command argv cannot be empty") + } + + cmd := exec.CommandContext(ctx, argv[0], argv[1:]...) + stdin, err := cmd.StdinPipe() + if err != nil { + return fmt.Errorf("open stdin for %s: %w", argv[0], err) + } + + if err := cmd.Start(); err != nil { + _ = stdin.Close() + return fmt.Errorf("start command %s: %w", argv[0], err) + } + + if input != "" { + if _, err := stdin.Write([]byte(input)); err != nil { + _ = stdin.Close() + _ = cmd.Wait() + return fmt.Errorf("write stdin for %s: %w", argv[0], err) + } + } + _ = stdin.Close() + + if err := cmd.Wait(); err != nil { + return fmt.Errorf("wait for %s: %w", argv[0], err) + } + return nil +} + +// logPasteFailure records paste errors while preserving clipboard success semantics. +func (c *Committer) logPasteFailure(err error) { + if c.logger == nil || err == nil { + return + } + c.logger.Error("paste dispatch failed; clipboard remains set", "error", err.Error()) +} diff --git a/apps/sotto/internal/output/clipboard_test.go b/apps/sotto/internal/output/clipboard_test.go new file mode 100644 index 0000000..07081ef --- /dev/null +++ b/apps/sotto/internal/output/clipboard_test.go @@ -0,0 +1,162 @@ +package output + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + + "github.com/rbright/sotto/internal/config" + "github.com/stretchr/testify/require" +) + +func TestRunCommandWithInputWritesStdin(t *testing.T) { + scriptPath := writeStdinCaptureScript(t) + outputPath := filepath.Join(t.TempDir(), "stdin.txt") + + err := runCommandWithInput(context.Background(), []string{scriptPath, outputPath}, "hello from sotto") + require.NoError(t, err) + + data, err := os.ReadFile(outputPath) + require.NoError(t, err) + require.Equal(t, "hello from sotto", string(data)) +} + +func TestRunCommandWithInputRejectsEmptyArgv(t *testing.T) { + err := runCommandWithInput(context.Background(), nil, "payload") + require.Error(t, err) + require.Contains(t, err.Error(), "argv cannot be empty") +} + +func TestCommitterCommitWritesClipboardWhenPasteDisabled(t *testing.T) { + scriptPath := writeStdinCaptureScript(t) + clipboardPath := filepath.Join(t.TempDir(), "clipboard.txt") + + cfg := config.Default() + cfg.Paste.Enable = false + cfg.Clipboard = config.CommandConfig{Argv: []string{scriptPath, clipboardPath}} + + committer := NewCommitter(cfg, nil) + err := committer.Commit(context.Background(), "captured transcript") + require.NoError(t, err) + + data, err := os.ReadFile(clipboardPath) + require.NoError(t, err) + require.Equal(t, "captured transcript", string(data)) +} + +func TestCommitterCommitSkipsEmptyTranscript(t *testing.T) { + scriptPath := writeStdinCaptureScript(t) + clipboardPath := filepath.Join(t.TempDir(), "clipboard.txt") + + cfg := config.Default() + cfg.Paste.Enable = false + cfg.Clipboard = config.CommandConfig{Argv: []string{scriptPath, clipboardPath}} + + committer := NewCommitter(cfg, nil) + err := committer.Commit(context.Background(), "") + require.NoError(t, err) + + _, statErr := os.Stat(clipboardPath) + require.Error(t, statErr) + require.True(t, os.IsNotExist(statErr)) +} + +func TestCommitterCommitReturnsErrorWhenClipboardCommandFails(t *testing.T) { + failScript := writeFailScript(t, "clipboard failed") + + cfg := config.Default() + cfg.Paste.Enable = false + cfg.Clipboard = config.CommandConfig{Argv: []string{failScript}} + + committer := NewCommitter(cfg, nil) + err := committer.Commit(context.Background(), "captured transcript") + require.Error(t, err) + require.Contains(t, err.Error(), "set clipboard") +} + +func TestCommitterCommitPasteCmdFailureDoesNotFailCommit(t *testing.T) { + clipboardScript := writeStdinCaptureScript(t) + clipboardPath := filepath.Join(t.TempDir(), "clipboard.txt") + pasteFailScript := writeFailScript(t, "paste failed") + + cfg := config.Default() + cfg.Clipboard = config.CommandConfig{Argv: []string{clipboardScript, clipboardPath}} + cfg.Paste.Enable = true + cfg.PasteCmd = config.CommandConfig{Argv: []string{pasteFailScript}} + + committer := NewCommitter(cfg, nil) + err := committer.Commit(context.Background(), "captured transcript") + require.NoError(t, err) + + data, readErr := os.ReadFile(clipboardPath) + require.NoError(t, readErr) + require.Equal(t, "captured transcript", string(data)) +} + +func TestCommitterCommitDefaultPasteFailureDoesNotFailCommit(t *testing.T) { + clipboardScript := writeStdinCaptureScript(t) + clipboardPath := filepath.Join(t.TempDir(), "clipboard.txt") + + argsFile := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", argsFile) + installHyprctlDefaultPasteFailStub(t) + + cfg := config.Default() + cfg.Clipboard = config.CommandConfig{Argv: []string{clipboardScript, clipboardPath}} + cfg.Paste.Enable = true + cfg.PasteCmd = config.CommandConfig{} + + committer := NewCommitter(cfg, nil) + err := committer.Commit(context.Background(), "captured transcript") + require.NoError(t, err) + + data, readErr := os.ReadFile(clipboardPath) + require.NoError(t, readErr) + require.Equal(t, "captured transcript", string(data)) +} + +func writeStdinCaptureScript(t *testing.T) string { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "capture-stdin.sh") + script := `#!/usr/bin/env bash +set -euo pipefail +cat > "$1" +` + require.NoError(t, os.WriteFile(path, []byte(script), 0o755)) + return path +} + +func writeFailScript(t *testing.T, message string) string { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "fail.sh") + script := "#!/usr/bin/env bash\nset -euo pipefail\necho " + "\"" + message + "\"" + " >&2\nexit 1\n" + require.NoError(t, os.WriteFile(path, []byte(script), 0o755)) + return path +} + +func installHyprctlDefaultPasteFailStub(t *testing.T) { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "hyprctl") + script := `#!/usr/bin/env bash +set -euo pipefail +if [[ "${1:-}" == "-j" && "${2:-}" == "activewindow" ]]; then + echo '{"address":"0xabc","class":"brave-browser","initialClass":"brave-browser"}' + exit 0 +fi +if [[ "${1:-}" == "--quiet" && "${2:-}" == "dispatch" && "${3:-}" == "sendshortcut" ]]; then + echo "sendshortcut failed" >&2 + exit 1 +fi +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +` + require.NoError(t, os.WriteFile(path, []byte(strings.TrimSpace(script)+"\n"), 0o755)) + t.Setenv("PATH", dir+":"+os.Getenv("PATH")) +} diff --git a/apps/sotto/internal/output/paste.go b/apps/sotto/internal/output/paste.go new file mode 100644 index 0000000..b95ba7d --- /dev/null +++ b/apps/sotto/internal/output/paste.go @@ -0,0 +1,68 @@ +package output + +import ( + "context" + "fmt" + "strings" + "time" + + "github.com/rbright/sotto/internal/hypr" +) + +// defaultPaste dispatches a sendshortcut payload to the current active window. +func defaultPaste(ctx context.Context, shortcut string) error { + window, err := activeWindowWithRetry(ctx, 5, 10*time.Millisecond) + if err != nil { + return err + } + + payload, err := buildPasteShortcut(shortcut, strings.TrimSpace(window.Address)) + if err != nil { + return err + } + return hypr.SendShortcut(ctx, payload) +} + +// buildPasteShortcut renders `,address:` payload format. +func buildPasteShortcut(shortcut string, windowAddress string) (string, error) { + shortcut = strings.TrimSpace(shortcut) + if shortcut == "" { + return "", fmt.Errorf("paste shortcut cannot be empty") + } + + address := strings.TrimSpace(windowAddress) + if address == "" { + return "", fmt.Errorf("active window address is required") + } + + return fmt.Sprintf("%s,address:%s", shortcut, address), nil +} + +// activeWindowWithRetry retries active-window lookup within short bounded delays. +func activeWindowWithRetry(ctx context.Context, attempts int, delay time.Duration) (hypr.ActiveWindow, error) { + if attempts <= 0 { + attempts = 1 + } + + var lastErr error + for i := 0; i < attempts; i++ { + window, err := hypr.QueryActiveWindow(ctx) + if err == nil { + return window, nil + } + lastErr = err + if i == attempts-1 { + break + } + select { + case <-ctx.Done(): + return hypr.ActiveWindow{}, ctx.Err() + case <-time.After(delay): + } + } + + if lastErr == nil { + lastErr = fmt.Errorf("active window unavailable") + } + return hypr.ActiveWindow{}, fmt.Errorf("resolve active window: %w", lastErr) +} diff --git a/apps/sotto/internal/output/paste_test.go b/apps/sotto/internal/output/paste_test.go new file mode 100644 index 0000000..2423e1a --- /dev/null +++ b/apps/sotto/internal/output/paste_test.go @@ -0,0 +1,89 @@ +package output + +import ( + "context" + "os" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestBuildPasteShortcut(t *testing.T) { + t.Parallel() + + t.Run("builds payload", func(t *testing.T) { + got, err := buildPasteShortcut("SUPER,V", "0xabc") + require.NoError(t, err) + require.Equal(t, "SUPER,V,address:0xabc", got) + }) + + t.Run("rejects empty shortcut", func(t *testing.T) { + _, err := buildPasteShortcut("", "0xabc") + require.Error(t, err) + require.Contains(t, err.Error(), "shortcut") + }) + + t.Run("rejects empty address", func(t *testing.T) { + _, err := buildPasteShortcut("CTRL,V", "") + require.Error(t, err) + require.Contains(t, err.Error(), "address") + }) +} + +func TestDefaultPasteDispatchesShortcut(t *testing.T) { + argsFile := filepath.Join(t.TempDir(), "hypr-args.log") + t.Setenv("HYPR_ARGS_FILE", argsFile) + t.Setenv("HYPR_ACTIVEWINDOW_JSON", `{"address":"0xabc","class":"ghostty","initialClass":"ghostty"}`) + installHyprctlPasteStub(t) + + err := defaultPaste(context.Background(), "SUPER,V") + require.NoError(t, err) + + data, err := os.ReadFile(argsFile) + require.NoError(t, err) + require.Contains(t, string(data), "--quiet dispatch sendshortcut SUPER,V,address:0xabc") +} + +func TestActiveWindowWithRetryHonorsContextCancel(t *testing.T) { + emptyPathDir := t.TempDir() + t.Setenv("PATH", emptyPathDir) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + _, err := activeWindowWithRetry(ctx, 3, 10*time.Millisecond) + require.ErrorIs(t, err, context.Canceled) +} + +func TestDefaultPasteFailsWhenActiveWindowAddressMissing(t *testing.T) { + t.Setenv("HYPR_ACTIVEWINDOW_JSON", `{"address":"","class":"brave-browser"}`) + installHyprctlPasteStub(t) + + err := defaultPaste(context.Background(), "CTRL,V") + require.Error(t, err) + require.Contains(t, err.Error(), "empty address") +} + +func installHyprctlPasteStub(t *testing.T) { + t.Helper() + + dir := t.TempDir() + path := filepath.Join(dir, "hyprctl") + script := `#!/usr/bin/env bash +set -euo pipefail +if [[ "${1:-}" == "-j" && "${2:-}" == "activewindow" ]]; then + if [[ -n "${HYPR_ACTIVEWINDOW_JSON:-}" ]]; then + echo "${HYPR_ACTIVEWINDOW_JSON}" + else + echo '{"address":"0xabc","class":"brave-browser","initialClass":"brave-browser"}' + fi + exit 0 +fi +printf '%s\n' "$*" >> "${HYPR_ARGS_FILE}" +` + require.NoError(t, os.WriteFile(path, []byte(strings.TrimSpace(script)+"\n"), 0o755)) + t.Setenv("PATH", dir+":"+os.Getenv("PATH")) +} diff --git a/apps/sotto/internal/pipeline/transcriber.go b/apps/sotto/internal/pipeline/transcriber.go new file mode 100644 index 0000000..bdff4f5 --- /dev/null +++ b/apps/sotto/internal/pipeline/transcriber.go @@ -0,0 +1,360 @@ +// Package pipeline orchestrates audio capture and Riva streaming for transcription. +package pipeline + +import ( + "context" + "encoding/binary" + "fmt" + "log/slog" + "os" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/rbright/sotto/internal/audio" + "github.com/rbright/sotto/internal/config" + "github.com/rbright/sotto/internal/riva" + "github.com/rbright/sotto/internal/session" + "github.com/rbright/sotto/internal/transcript" +) + +// Transcriber owns one end-to-end capture -> ASR -> transcript pipeline instance. +type Transcriber struct { + cfg config.Config + logger *slog.Logger + + mu sync.Mutex + started bool + + selection audio.Selection + capture *audio.Capture + stream *riva.Stream + + sendErrCh chan error + + debugGRPCFile *os.File +} + +// NewTranscriber constructs a pipeline transcriber from runtime config. +func NewTranscriber(cfg config.Config, logger *slog.Logger) *Transcriber { + return &Transcriber{cfg: cfg, logger: logger} +} + +// Start resolves device selection, opens Riva stream, and starts audio capture. +func (t *Transcriber) Start(ctx context.Context) error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.started { + return fmt.Errorf("transcriber already started") + } + + selection, err := audio.SelectDevice(ctx, t.cfg.Audio.Input, t.cfg.Audio.Fallback) + if err != nil { + return err + } + t.selection = selection + if selection.Warning != "" { + t.logWarn(selection.Warning) + } + + speechPhrases, _, err := config.BuildSpeechPhrases(t.cfg) + if err != nil { + return fmt.Errorf("build speech contexts: %w", err) + } + + if t.cfg.Debug.EnableGRPCDump { + file, ferr := createDebugFile("grpc", "json") + if ferr != nil { + return ferr + } + t.debugGRPCFile = file + } + + rivaPhrases := make([]riva.SpeechPhrase, 0, len(speechPhrases)) + for _, phrase := range speechPhrases { + rivaPhrases = append(rivaPhrases, riva.SpeechPhrase{Phrase: phrase.Phrase, Boost: phrase.Boost}) + } + + stream, err := riva.DialStream(ctx, riva.StreamConfig{ + Endpoint: t.cfg.RivaGRPC, + LanguageCode: t.cfg.ASR.LanguageCode, + Model: t.cfg.ASR.Model, + AutomaticPunctuation: t.cfg.ASR.AutomaticPunctuation, + SpeechPhrases: rivaPhrases, + DialTimeout: 3 * time.Second, + DebugResponseSinkJSON: func() *os.File { + if t.debugGRPCFile == nil { + return nil + } + return t.debugGRPCFile + }(), + }) + if err != nil { + t.closeDebugArtifactsLocked() + return err + } + t.stream = stream + + capture, err := audio.StartCapture(ctx, selection.Device) + if err != nil { + _ = stream.Cancel() + t.closeDebugArtifactsLocked() + return err + } + t.capture = capture + + t.sendErrCh = make(chan error, 1) + go t.sendLoop() + + t.started = true + return nil +} + +// StopAndTranscribe stops capture, closes stream, and assembles the transcript. +func (t *Transcriber) StopAndTranscribe(ctx context.Context) (session.StopResult, error) { + t.mu.Lock() + started := t.started + capture := t.capture + stream := t.stream + sendErrCh := t.sendErrCh + selection := t.selection + t.mu.Unlock() + + if !started || capture == nil || stream == nil { + return session.StopResult{}, session.ErrPipelineUnavailable + } + defer t.resetRuntimeState() + + _ = capture.Stop() + + var sendErr error + if sendErrCh != nil { + sendErr = <-sendErrCh + } + if sendErr != nil { + _ = stream.Cancel() + result := session.StopResult{ + AudioDevice: describeDevice(selection.Device), + BytesCaptured: capture.BytesCaptured(), + } + t.writeDebugAudio(capture.RawPCM()) + t.closeDebugArtifacts() + return result, fmt.Errorf("send audio stream: %w", sendErr) + } + + closeCtx, cancel := context.WithTimeout(ctx, 20*time.Second) + defer cancel() + segments, grpcLatency, err := stream.CloseAndCollect(closeCtx) + if err != nil { + result := session.StopResult{ + AudioDevice: describeDevice(selection.Device), + BytesCaptured: capture.BytesCaptured(), + GRPCLatency: grpcLatency, + } + t.writeDebugAudio(capture.RawPCM()) + t.closeDebugArtifacts() + return result, fmt.Errorf("collect final transcript: %w", err) + } + + transcribed := transcript.Assemble(segments, t.cfg.Transcript.TrailingSpace) + rawPCM := capture.RawPCM() + t.writeDebugAudio(rawPCM) + t.closeDebugArtifacts() + + return session.StopResult{ + Transcript: transcribed, + AudioDevice: describeDevice(selection.Device), + BytesCaptured: capture.BytesCaptured(), + GRPCLatency: grpcLatency, + }, nil +} + +// Cancel stops capture and stream immediately without transcript commit. +func (t *Transcriber) Cancel(_ context.Context) error { + t.mu.Lock() + capture := t.capture + stream := t.stream + t.mu.Unlock() + defer t.resetRuntimeState() + + if capture != nil { + _ = capture.Stop() + t.writeDebugAudio(capture.RawPCM()) + } + if stream != nil { + _ = stream.Cancel() + } + t.closeDebugArtifacts() + return nil +} + +// resetRuntimeState clears one-shot runtime resources so the transcriber can be reused. +func (t *Transcriber) resetRuntimeState() { + t.mu.Lock() + defer t.mu.Unlock() + t.started = false + t.capture = nil + t.stream = nil + t.sendErrCh = nil +} + +// sendLoop forwards capture chunks to Riva and reports the first send failure. +func (t *Transcriber) sendLoop() { + t.mu.Lock() + capture := t.capture + stream := t.stream + errCh := t.sendErrCh + t.mu.Unlock() + + if errCh == nil { + return + } + + sent := false + sendResult := func(err error) { + if sent { + return + } + errCh <- err + sent = true + } + defer sendResult(nil) + + if capture == nil || stream == nil { + sendResult(session.ErrPipelineUnavailable) + return + } + + for chunk := range capture.Chunks() { + if len(chunk) == 0 { + continue + } + if err := stream.SendAudio(chunk); err != nil { + _ = capture.Stop() + sendResult(err) + return + } + } +} + +// describeDevice formats device metadata for logs/session results. +func describeDevice(device audio.Device) string { + description := strings.TrimSpace(device.Description) + id := strings.TrimSpace(device.ID) + if description == "" { + return id + } + if id == "" { + return description + } + return fmt.Sprintf("%s (%s)", description, id) +} + +// logWarn emits warning-level logs when logger is configured. +func (t *Transcriber) logWarn(message string) { + if t.logger == nil { + return + } + t.logger.Warn(message) +} + +// createDebugFile creates timestamped debug artifacts under state/sotto/debug. +func createDebugFile(prefix string, extension string) (*os.File, error) { + stateDir, err := resolveStateDir() + if err != nil { + return nil, err + } + debugDir := filepath.Join(stateDir, "sotto", "debug") + if err := os.MkdirAll(debugDir, 0o700); err != nil { + return nil, fmt.Errorf("create debug dir: %w", err) + } + + timestamp := time.Now().Format("20060102-150405.000") + path := filepath.Join(debugDir, fmt.Sprintf("%s-%s.%s", prefix, timestamp, extension)) + file, err := os.OpenFile(path, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0o600) + if err != nil { + return nil, fmt.Errorf("open debug file %q: %w", path, err) + } + return file, nil +} + +// resolveStateDir returns XDG_STATE_HOME fallback path for debug artifacts. +func resolveStateDir() (string, error) { + if xdg := strings.TrimSpace(os.Getenv("XDG_STATE_HOME")); xdg != "" { + return xdg, nil + } + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("resolve home directory for state: %w", err) + } + return filepath.Join(home, ".local", "state"), nil +} + +// closeDebugArtifacts closes open debug sinks. +func (t *Transcriber) closeDebugArtifacts() { + t.mu.Lock() + defer t.mu.Unlock() + t.closeDebugArtifactsLocked() +} + +// closeDebugArtifactsLocked closes debug sinks while caller holds t.mu. +func (t *Transcriber) closeDebugArtifactsLocked() { + if t.debugGRPCFile != nil { + _ = t.debugGRPCFile.Close() + t.debugGRPCFile = nil + } +} + +// writeDebugAudio writes raw PCM to WAV when debug.audio_dump is enabled. +func (t *Transcriber) writeDebugAudio(rawPCM []byte) { + if !t.cfg.Debug.EnableAudioDump || len(rawPCM) == 0 { + return + } + + file, err := createDebugFile("audio", "wav") + if err != nil { + t.logWarn(fmt.Sprintf("unable to create debug audio dump: %v", err)) + return + } + defer file.Close() + + if err := writePCM16WAV(file, rawPCM, 16000, 1); err != nil { + t.logWarn(fmt.Sprintf("unable to write debug audio dump: %v", err)) + } +} + +// writePCM16WAV writes raw little-endian PCM bytes with a minimal WAV header. +func writePCM16WAV(file *os.File, pcm []byte, sampleRate int, channels int) error { + if channels <= 0 { + channels = 1 + } + const bitsPerSample = 16 + byteRate := sampleRate * channels * (bitsPerSample / 8) + blockAlign := channels * (bitsPerSample / 8) + + chunkSize := uint32(36 + len(pcm)) + subChunk2Size := uint32(len(pcm)) + + header := make([]byte, 44) + copy(header[0:4], []byte("RIFF")) + binary.LittleEndian.PutUint32(header[4:8], chunkSize) + copy(header[8:12], []byte("WAVE")) + copy(header[12:16], []byte("fmt ")) + binary.LittleEndian.PutUint32(header[16:20], 16) + binary.LittleEndian.PutUint16(header[20:22], 1) // PCM + binary.LittleEndian.PutUint16(header[22:24], uint16(channels)) + binary.LittleEndian.PutUint32(header[24:28], uint32(sampleRate)) + binary.LittleEndian.PutUint32(header[28:32], uint32(byteRate)) + binary.LittleEndian.PutUint16(header[32:34], uint16(blockAlign)) + binary.LittleEndian.PutUint16(header[34:36], bitsPerSample) + copy(header[36:40], []byte("data")) + binary.LittleEndian.PutUint32(header[40:44], subChunk2Size) + + if _, err := file.Write(header); err != nil { + return err + } + _, err := file.Write(pcm) + return err +} diff --git a/apps/sotto/internal/pipeline/transcriber_test.go b/apps/sotto/internal/pipeline/transcriber_test.go new file mode 100644 index 0000000..295fa91 --- /dev/null +++ b/apps/sotto/internal/pipeline/transcriber_test.go @@ -0,0 +1,181 @@ +package pipeline + +import ( + "context" + "encoding/binary" + "os" + "path/filepath" + "testing" + + "github.com/rbright/sotto/internal/audio" + "github.com/rbright/sotto/internal/config" + "github.com/rbright/sotto/internal/session" + "github.com/stretchr/testify/require" +) + +func TestDescribeDevice(t *testing.T) { + require.Equal(t, "Elgato (alsa_input.wave3)", describeDevice(audio.Device{Description: "Elgato", ID: "alsa_input.wave3"})) + require.Equal(t, "Elgato", describeDevice(audio.Device{Description: "Elgato"})) + require.Equal(t, "alsa_input.wave3", describeDevice(audio.Device{ID: "alsa_input.wave3"})) +} + +func TestResolveStateDirUsesXDGStateHome(t *testing.T) { + xdgStateHome := t.TempDir() + t.Setenv("XDG_STATE_HOME", xdgStateHome) + t.Setenv("HOME", t.TempDir()) + + dir, err := resolveStateDir() + require.NoError(t, err) + require.Equal(t, xdgStateHome, dir) +} + +func TestResolveStateDirFallsBackToHome(t *testing.T) { + home := t.TempDir() + t.Setenv("XDG_STATE_HOME", "") + t.Setenv("HOME", home) + + dir, err := resolveStateDir() + require.NoError(t, err) + require.Equal(t, filepath.Join(home, ".local", "state"), dir) +} + +func TestCreateDebugFileCreatesExpectedPath(t *testing.T) { + t.Setenv("XDG_STATE_HOME", t.TempDir()) + + file, err := createDebugFile("grpc", "json") + require.NoError(t, err) + path := file.Name() + require.NoError(t, file.Close()) + + require.FileExists(t, path) + require.Contains(t, path, string(filepath.Separator)+"sotto"+string(filepath.Separator)+"debug"+string(filepath.Separator)) + require.Contains(t, filepath.Base(path), "grpc-") + require.Equal(t, ".json", filepath.Ext(path)) + + stat, err := os.Stat(path) + require.NoError(t, err) + require.Equal(t, os.FileMode(0o600), stat.Mode().Perm()) +} + +func TestWritePCM16WAVWritesHeaderAndPCM(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "*.wav") + require.NoError(t, err) + + pcm := []byte{0x01, 0x00, 0xFF, 0x7F} + require.NoError(t, writePCM16WAV(file, pcm, 16000, 0)) + require.NoError(t, file.Close()) + + data, err := os.ReadFile(file.Name()) + require.NoError(t, err) + require.Len(t, data, 44+len(pcm)) + + require.Equal(t, "RIFF", string(data[0:4])) + require.Equal(t, "WAVE", string(data[8:12])) + require.Equal(t, "fmt ", string(data[12:16])) + require.Equal(t, "data", string(data[36:40])) + require.Equal(t, uint16(1), binary.LittleEndian.Uint16(data[22:24])) // channels default to mono + require.Equal(t, uint32(len(pcm)), binary.LittleEndian.Uint32(data[40:44])) + require.Equal(t, pcm, data[44:]) +} + +func TestWriteDebugAudioCreatesWavWhenEnabled(t *testing.T) { + xdgStateHome := t.TempDir() + t.Setenv("XDG_STATE_HOME", xdgStateHome) + + cfg := config.Default() + cfg.Debug.EnableAudioDump = true + transcriber := NewTranscriber(cfg, nil) + + transcriber.writeDebugAudio([]byte{0x01, 0x00, 0x02, 0x00}) + + matches, err := filepath.Glob(filepath.Join(xdgStateHome, "sotto", "debug", "audio-*.wav")) + require.NoError(t, err) + require.NotEmpty(t, matches) +} + +func TestWriteDebugAudioSkippedWhenDisabled(t *testing.T) { + xdgStateHome := t.TempDir() + t.Setenv("XDG_STATE_HOME", xdgStateHome) + + cfg := config.Default() + cfg.Debug.EnableAudioDump = false + transcriber := NewTranscriber(cfg, nil) + + transcriber.writeDebugAudio([]byte{0x01, 0x00, 0x02, 0x00}) + + matches, err := filepath.Glob(filepath.Join(xdgStateHome, "sotto", "debug", "audio-*.wav")) + require.NoError(t, err) + require.Empty(t, matches) +} + +func TestStartFailsWhenAlreadyStarted(t *testing.T) { + transcriber := NewTranscriber(config.Default(), nil) + transcriber.started = true + + err := transcriber.Start(context.Background()) + require.Error(t, err) + require.Contains(t, err.Error(), "already started") +} + +func TestStartFailsWhenAudioSelectionUnavailable(t *testing.T) { + t.Setenv("PULSE_SERVER", "unix:/tmp/definitely-missing-pulse-server") + + transcriber := NewTranscriber(config.Default(), nil) + err := transcriber.Start(context.Background()) + require.Error(t, err) +} + +func TestStopAndTranscribeUnavailableWhenNotStarted(t *testing.T) { + result, err := NewTranscriber(config.Default(), nil).StopAndTranscribe(context.Background()) + require.ErrorIs(t, err, session.ErrPipelineUnavailable) + require.Equal(t, session.StopResult{}, result) +} + +func TestCancelWithoutInitializedPipeline(t *testing.T) { + transcriber := NewTranscriber(config.Default(), nil) + require.NoError(t, transcriber.Cancel(context.Background())) +} + +func TestSendLoopNoopWhenUninitialized(t *testing.T) { + transcriber := NewTranscriber(config.Default(), nil) + transcriber.sendLoop() // should return immediately without panic +} + +func TestSendLoopSignalsPipelineUnavailableWhenChannelPresent(t *testing.T) { + transcriber := NewTranscriber(config.Default(), nil) + transcriber.sendErrCh = make(chan error, 1) + + transcriber.sendLoop() + + err := <-transcriber.sendErrCh + require.ErrorIs(t, err, session.ErrPipelineUnavailable) +} + +func TestCloseDebugArtifactsClosesFile(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "*.json") + require.NoError(t, err) + + transcriber := NewTranscriber(config.Default(), nil) + transcriber.debugGRPCFile = file + transcriber.closeDebugArtifacts() + + _, err = file.Write([]byte("x")) + require.Error(t, err) + require.Nil(t, transcriber.debugGRPCFile) +} + +func TestCloseDebugArtifactsLockedClosesFileWhileMutexHeld(t *testing.T) { + file, err := os.CreateTemp(t.TempDir(), "*.json") + require.NoError(t, err) + + transcriber := NewTranscriber(config.Default(), nil) + transcriber.debugGRPCFile = file + + transcriber.mu.Lock() + transcriber.closeDebugArtifactsLocked() + transcriber.mu.Unlock() + + _, err = file.Write([]byte("x")) + require.Error(t, err) + require.Nil(t, transcriber.debugGRPCFile) +} diff --git a/apps/sotto/internal/riva/client.go b/apps/sotto/internal/riva/client.go new file mode 100644 index 0000000..dee7031 --- /dev/null +++ b/apps/sotto/internal/riva/client.go @@ -0,0 +1,212 @@ +// Package riva implements the Riva gRPC streaming client adapter. +package riva + +import ( + "context" + "errors" + "fmt" + "io" + "strings" + "sync" + "time" + + asrpb "github.com/rbright/sotto/proto/gen/go/riva/proto" + "google.golang.org/grpc" + "google.golang.org/grpc/credentials/insecure" +) + +// SpeechPhrase is one vocabulary boost phrase in request-ready form. +type SpeechPhrase struct { + Phrase string + Boost float32 +} + +// StreamConfig controls stream initialization and recognition behavior. +type StreamConfig struct { + Endpoint string + LanguageCode string + Model string + AutomaticPunctuation bool + SpeechPhrases []SpeechPhrase + DialTimeout time.Duration + DebugResponseSinkJSON io.Writer +} + +// Stream wraps one active Riva StreamingRecognize RPC lifecycle. +type Stream struct { + conn *grpc.ClientConn + stream asrpb.RivaSpeechRecognition_StreamingRecognizeClient + + cancel context.CancelFunc + + recvDone chan struct{} + + mu sync.Mutex + segments []string // committed transcript segments (final and pause-committed interim) + lastInterim string + recvErr error + closedSend bool + debugSinkJSON io.Writer +} + +// DialStream establishes a stream, sends config, and starts the receive loop. +func DialStream(ctx context.Context, cfg StreamConfig) (*Stream, error) { + endpoint := strings.TrimSpace(cfg.Endpoint) + if endpoint == "" { + return nil, errors.New("riva endpoint is empty") + } + if cfg.DialTimeout <= 0 { + cfg.DialTimeout = 3 * time.Second + } + if strings.TrimSpace(cfg.LanguageCode) == "" { + cfg.LanguageCode = "en-US" + } + + conn, err := grpc.NewClient( + endpoint, + grpc.WithTransportCredentials(insecure.NewCredentials()), + ) + if err != nil { + return nil, fmt.Errorf("dial riva grpc %q: %w", endpoint, err) + } + + readyCtx, cancel := context.WithTimeout(ctx, cfg.DialTimeout) + defer cancel() + conn.Connect() + if err := waitForReady(readyCtx, conn); err != nil { + _ = conn.Close() + return nil, fmt.Errorf("wait for riva grpc readiness: %w", err) + } + + streamCtx, streamCancel := context.WithCancel(ctx) + client := asrpb.NewRivaSpeechRecognitionClient(conn) + stream, err := openRecognizeWithTimeout(streamCtx, cfg.DialTimeout, func() (asrpb.RivaSpeechRecognition_StreamingRecognizeClient, error) { + return client.StreamingRecognize(streamCtx) + }) + if err != nil { + streamCancel() + _ = conn.Close() + return nil, fmt.Errorf("open streaming recognizer: %w", err) + } + + req := &asrpb.StreamingRecognizeRequest{ + StreamingRequest: &asrpb.StreamingRecognizeRequest_StreamingConfig{ + StreamingConfig: &asrpb.StreamingRecognitionConfig{ + Config: &asrpb.RecognitionConfig{ + Encoding: asrpb.AudioEncoding_LINEAR_PCM, + SampleRateHertz: 16000, + LanguageCode: cfg.LanguageCode, + EnableAutomaticPunctuation: cfg.AutomaticPunctuation, + AudioChannelCount: 1, + Model: strings.TrimSpace(cfg.Model), + }, + InterimResults: true, + }, + }, + } + + for _, phrase := range cfg.SpeechPhrases { + phraseText := strings.TrimSpace(phrase.Phrase) + if phraseText == "" { + continue + } + req.GetStreamingConfig().GetConfig().SpeechContexts = append( + req.GetStreamingConfig().GetConfig().SpeechContexts, + &asrpb.SpeechContext{Phrases: []string{phraseText}, Boost: phrase.Boost}, + ) + } + + if err := runWithTimeout(streamCtx, cfg.DialTimeout, func() error { + return stream.Send(req) + }); err != nil { + streamCancel() + _ = conn.Close() + return nil, fmt.Errorf("send initial streaming config: %w", err) + } + + s := &Stream{ + conn: conn, + stream: stream, + cancel: streamCancel, + recvDone: make(chan struct{}), + debugSinkJSON: cfg.DebugResponseSinkJSON, + } + go s.recvLoop() + return s, nil +} + +// SendAudio sends one chunk of PCM audio over the active stream. +func (s *Stream) SendAudio(chunk []byte) error { + if len(chunk) == 0 { + return nil + } + + s.mu.Lock() + closed := s.closedSend + recvErr := s.recvErr + s.mu.Unlock() + + if closed { + return errors.New("stream already closed for sending") + } + if recvErr != nil { + return fmt.Errorf("stream receive loop failed: %w", recvErr) + } + + return s.stream.Send(&asrpb.StreamingRecognizeRequest{ + StreamingRequest: &asrpb.StreamingRecognizeRequest_AudioContent{AudioContent: chunk}, + }) +} + +// CloseAndCollect closes send-side audio and returns merged transcript segments. +func (s *Stream) CloseAndCollect(ctx context.Context) ([]string, time.Duration, error) { + closedAt := time.Now() + + s.mu.Lock() + if !s.closedSend { + s.closedSend = true + _ = s.stream.CloseSend() + } + s.mu.Unlock() + + select { + case <-s.recvDone: + case <-ctx.Done(): + if s.cancel != nil { + s.cancel() + } + _ = s.conn.Close() + return nil, 0, ctx.Err() + } + latency := time.Since(closedAt) + + s.mu.Lock() + defer s.mu.Unlock() + defer func() { + if s.cancel != nil { + s.cancel() + } + _ = s.conn.Close() + }() + + if s.recvErr != nil { + return nil, latency, s.recvErr + } + + segments := collectSegments(s.segments, s.lastInterim) + return segments, latency, nil +} + +// Cancel aborts stream processing and closes the underlying grpc connection. +func (s *Stream) Cancel() error { + s.mu.Lock() + if !s.closedSend { + s.closedSend = true + _ = s.stream.CloseSend() + } + s.mu.Unlock() + if s.cancel != nil { + s.cancel() + } + return s.conn.Close() +} diff --git a/apps/sotto/internal/riva/client_test.go b/apps/sotto/internal/riva/client_test.go new file mode 100644 index 0000000..679103f --- /dev/null +++ b/apps/sotto/internal/riva/client_test.go @@ -0,0 +1,299 @@ +package riva + +import ( + "bytes" + "context" + "errors" + "io" + "net" + "testing" + "time" + + asrpb "github.com/rbright/sotto/proto/gen/go/riva/proto" + "github.com/stretchr/testify/require" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func TestCollectSegmentsAppendsTrailingInterim(t *testing.T) { + got := collectSegments([]string{"hello there"}, "how are you") + require.Equal(t, []string{"hello there", "how are you"}, got) +} + +func TestCollectSegmentsFallsBackToInterim(t *testing.T) { + got := collectSegments(nil, " tentative words ") + require.Equal(t, []string{"tentative words"}, got) +} + +func TestRecordResponseTracksInterimThenFinal(t *testing.T) { + s := &Stream{} + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "hello wor"}}, + }}, + }) + + require.Equal(t, "hello wor", s.lastInterim) + require.Empty(t, s.segments) + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: true, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "hello world"}}, + }}, + }) + + require.Empty(t, s.lastInterim) + require.Equal(t, []string{"hello world"}, s.segments) +} + +func TestRecordResponseCommitsInterimAcrossPauseLikeReset(t *testing.T) { + s := &Stream{} + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "first phrase"}}, + }}, + }) + + s.recordResponse(&asrpb.StreamingRecognizeResponse{ + Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "second phrase"}}, + }}, + }) + + segments := collectSegments(s.segments, s.lastInterim) + require.Equal(t, []string{"first phrase", "second phrase"}, segments) +} + +func TestAppendSegmentDedupAndPrefixMerge(t *testing.T) { + segments := appendSegment(nil, "hello") + require.Equal(t, []string{"hello"}, segments) + + segments = appendSegment(segments, "hello") + require.Equal(t, []string{"hello"}, segments) + + segments = appendSegment(segments, "hello world") + require.Equal(t, []string{"hello world"}, segments) + + segments = appendSegment(segments, "hello") + require.Equal(t, []string{"hello world"}, segments) + + segments = appendSegment(segments, "new sentence") + require.Equal(t, []string{"hello world", "new sentence"}, segments) +} + +func TestCleanSegmentAndInterimContinuation(t *testing.T) { + require.Equal(t, "hello world", cleanSegment(" hello\n world ")) + require.Empty(t, cleanSegment(" \n\t")) + + require.True(t, isInterimContinuation("hello", "hello world")) + require.True(t, isInterimContinuation("hello world", "hello")) + require.False(t, isInterimContinuation("first phrase", "second phrase")) +} + +func TestDialStreamEndToEndWithDebugSinkAndSpeechContexts(t *testing.T) { + server := &testRivaServer{ + responses: []*asrpb.StreamingRecognizeResponse{ + {Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "hello wor"}}, + }}}, + {Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: true, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "hello world"}}, + }}}, + {Results: []*asrpb.StreamingRecognitionResult{{ + IsFinal: false, + Alternatives: []*asrpb.SpeechRecognitionAlternative{{Transcript: "second phrase"}}, + }}}, + }, + } + endpoint, shutdown := startTestRivaServer(t, server) + defer shutdown() + + var debug bytes.Buffer + + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + stream, err := DialStream(ctx, StreamConfig{ + Endpoint: endpoint, + LanguageCode: "en-US", + Model: "parakeet", + AutomaticPunctuation: true, + SpeechPhrases: []SpeechPhrase{ + {Phrase: " Sotto ", Boost: 12}, + {Phrase: "", Boost: 20}, + }, + DialTimeout: 2 * time.Second, + DebugResponseSinkJSON: &debug, + }) + require.NoError(t, err) + + require.NoError(t, stream.SendAudio([]byte{1, 2, 3, 4})) + require.NoError(t, stream.SendAudio(nil)) // no-op path + + segments, latency, err := stream.CloseAndCollect(ctx) + require.NoError(t, err) + require.Equal(t, []string{"hello world", "second phrase"}, segments) + require.GreaterOrEqual(t, latency, time.Duration(0)) + + require.NotNil(t, server.receivedConfig) + require.Equal(t, int32(16000), server.receivedConfig.Config.SampleRateHertz) + require.Equal(t, int32(1), server.receivedConfig.Config.AudioChannelCount) + require.Equal(t, "en-US", server.receivedConfig.Config.LanguageCode) + require.Equal(t, "parakeet", server.receivedConfig.Config.Model) + require.True(t, server.receivedConfig.Config.EnableAutomaticPunctuation) + require.Len(t, server.receivedConfig.Config.SpeechContexts, 1) + require.Equal(t, []string{"Sotto"}, server.receivedConfig.Config.SpeechContexts[0].Phrases) + require.Equal(t, 1, server.audioChunks) + + require.Contains(t, debug.String(), "results") +} + +func TestDialStreamEmptyEndpoint(t *testing.T) { + _, err := DialStream(context.Background(), StreamConfig{Endpoint: " "}) + require.Error(t, err) + require.Contains(t, err.Error(), "endpoint is empty") +} + +func TestDialStreamReadinessTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + _, err := DialStream(ctx, StreamConfig{ + Endpoint: "127.0.0.1:1", + DialTimeout: 100 * time.Millisecond, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "readiness") +} + +func TestRunWithTimeoutTimesOut(t *testing.T) { + err := runWithTimeout(context.Background(), 20*time.Millisecond, func() error { + time.Sleep(120 * time.Millisecond) + return nil + }) + require.Error(t, err) + require.Contains(t, err.Error(), "timed out") +} + +func TestOpenRecognizeWithTimeoutTimesOut(t *testing.T) { + _, err := openRecognizeWithTimeout(context.Background(), 20*time.Millisecond, func() (asrpb.RivaSpeechRecognition_StreamingRecognizeClient, error) { + time.Sleep(120 * time.Millisecond) + return nil, nil + }) + require.Error(t, err) + require.Contains(t, err.Error(), "timed out") +} + +func TestRunWithTimeoutReturnsCallError(t *testing.T) { + want := errors.New("boom") + err := runWithTimeout(context.Background(), time.Second, func() error { + return want + }) + require.ErrorIs(t, err, want) +} + +func TestCloseAndCollectReturnsServerStreamError(t *testing.T) { + server := &testRivaServer{streamErr: status.Error(codes.Internal, "boom")} + endpoint, shutdown := startTestRivaServer(t, server) + defer shutdown() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + stream, err := DialStream(ctx, StreamConfig{Endpoint: endpoint, DialTimeout: time.Second}) + require.NoError(t, err) + require.NoError(t, stream.SendAudio([]byte{1, 2})) + + _, _, err = stream.CloseAndCollect(ctx) + require.Error(t, err) + require.Contains(t, err.Error(), "boom") +} + +func TestSendAudioAfterCloseReturnsError(t *testing.T) { + server := &testRivaServer{} + endpoint, shutdown := startTestRivaServer(t, server) + defer shutdown() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + stream, err := DialStream(ctx, StreamConfig{Endpoint: endpoint, DialTimeout: time.Second}) + require.NoError(t, err) + + _, _, err = stream.CloseAndCollect(ctx) + require.NoError(t, err) + + err = stream.SendAudio([]byte{9, 9, 9}) + require.Error(t, err) + require.Contains(t, err.Error(), "closed") +} + +type testRivaServer struct { + asrpb.UnimplementedRivaSpeechRecognitionServer + + responses []*asrpb.StreamingRecognizeResponse + streamErr error + + receivedConfig *asrpb.StreamingRecognitionConfig + audioChunks int +} + +func (s *testRivaServer) StreamingRecognize(stream grpc.BidiStreamingServer[asrpb.StreamingRecognizeRequest, asrpb.StreamingRecognizeResponse]) error { + for { + req, err := stream.Recv() + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return err + } + + if cfg := req.GetStreamingConfig(); cfg != nil { + s.receivedConfig = cfg + continue + } + if len(req.GetAudioContent()) > 0 { + s.audioChunks++ + } + } + + for _, resp := range s.responses { + if err := stream.Send(resp); err != nil { + return err + } + } + if s.streamErr != nil { + return s.streamErr + } + return nil +} + +func startTestRivaServer(t *testing.T, srv asrpb.RivaSpeechRecognitionServer) (string, func()) { + t.Helper() + + lis, err := net.Listen("tcp", "127.0.0.1:0") + require.NoError(t, err) + + grpcServer := grpc.NewServer() + asrpb.RegisterRivaSpeechRecognitionServer(grpcServer, srv) + + go func() { + _ = grpcServer.Serve(lis) + }() + + shutdown := func() { + grpcServer.Stop() + _ = lis.Close() + } + + return lis.Addr().String(), shutdown +} diff --git a/apps/sotto/internal/riva/grpc_ready.go b/apps/sotto/internal/riva/grpc_ready.go new file mode 100644 index 0000000..6810cea --- /dev/null +++ b/apps/sotto/internal/riva/grpc_ready.go @@ -0,0 +1,30 @@ +package riva + +import ( + "context" + "errors" + "fmt" + + "google.golang.org/grpc" + "google.golang.org/grpc/connectivity" +) + +// waitForReady blocks until gRPC connection enters Ready or fails. +func waitForReady(ctx context.Context, conn *grpc.ClientConn) error { + for { + state := conn.GetState() + switch state { + case connectivity.Ready: + return nil + case connectivity.Shutdown: + return errors.New("grpc connection entered shutdown state") + } + + if !conn.WaitForStateChange(ctx, state) { + if ctx.Err() != nil { + return ctx.Err() + } + return fmt.Errorf("grpc readiness wait timed out in state %s", state.String()) + } + } +} diff --git a/apps/sotto/internal/riva/stream_init.go b/apps/sotto/internal/riva/stream_init.go new file mode 100644 index 0000000..a375d08 --- /dev/null +++ b/apps/sotto/internal/riva/stream_init.go @@ -0,0 +1,67 @@ +package riva + +import ( + "context" + "fmt" + "time" + + asrpb "github.com/rbright/sotto/proto/gen/go/riva/proto" +) + +type openResult struct { + stream asrpb.RivaSpeechRecognition_StreamingRecognizeClient + err error +} + +// openRecognizeWithTimeout bounds stream-open latency when backend RPCs stall. +func openRecognizeWithTimeout( + ctx context.Context, + timeout time.Duration, + open func() (asrpb.RivaSpeechRecognition_StreamingRecognizeClient, error), +) (asrpb.RivaSpeechRecognition_StreamingRecognizeClient, error) { + if timeout <= 0 { + return open() + } + + resultCh := make(chan openResult, 1) + go func() { + stream, err := open() + resultCh <- openResult{stream: stream, err: err} + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-timer.C: + return nil, fmt.Errorf("timed out after %s", timeout) + case result := <-resultCh: + return result.stream, result.err + } +} + +// runWithTimeout bounds one blocking stream operation (for example initial Send). +func runWithTimeout(ctx context.Context, timeout time.Duration, call func() error) error { + if timeout <= 0 { + return call() + } + + resultCh := make(chan error, 1) + go func() { + resultCh <- call() + }() + + timer := time.NewTimer(timeout) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return fmt.Errorf("timed out after %s", timeout) + case err := <-resultCh: + return err + } +} diff --git a/apps/sotto/internal/riva/stream_receive.go b/apps/sotto/internal/riva/stream_receive.go new file mode 100644 index 0000000..ef0e226 --- /dev/null +++ b/apps/sotto/internal/riva/stream_receive.go @@ -0,0 +1,64 @@ +package riva + +import ( + "encoding/json" + "errors" + "io" + + asrpb "github.com/rbright/sotto/proto/gen/go/riva/proto" +) + +// recvLoop continuously receives recognition responses until stream close/error. +func (s *Stream) recvLoop() { + defer close(s.recvDone) + + for { + resp, err := s.stream.Recv() + if err == nil { + s.recordResponse(resp) + continue + } + if errors.Is(err, io.EOF) { + return + } + + s.mu.Lock() + s.recvErr = err + s.mu.Unlock() + return + } +} + +// recordResponse merges final/interim segments into stream state. +func (s *Stream) recordResponse(resp *asrpb.StreamingRecognizeResponse) { + if sink := s.debugSinkJSON; sink != nil { + b, err := json.Marshal(resp) + if err == nil { + _, _ = sink.Write(append(b, '\n')) + } + } + + s.mu.Lock() + defer s.mu.Unlock() + + for _, result := range resp.GetResults() { + alternatives := result.GetAlternatives() + if len(alternatives) == 0 { + continue + } + transcript := cleanSegment(alternatives[0].GetTranscript()) + if transcript == "" { + continue + } + if result.GetIsFinal() { + s.segments = appendSegment(s.segments, transcript) + s.lastInterim = "" + continue + } + + if s.lastInterim != "" && !isInterimContinuation(s.lastInterim, transcript) { + s.segments = appendSegment(s.segments, s.lastInterim) + } + s.lastInterim = transcript + } +} diff --git a/apps/sotto/internal/riva/transcript_segments.go b/apps/sotto/internal/riva/transcript_segments.go new file mode 100644 index 0000000..0e887f6 --- /dev/null +++ b/apps/sotto/internal/riva/transcript_segments.go @@ -0,0 +1,88 @@ +package riva + +import "strings" + +// collectSegments appends a valid trailing interim segment when needed. +func collectSegments(committedSegments []string, lastInterim string) []string { + segments := append([]string(nil), committedSegments...) + if interim := cleanSegment(lastInterim); interim != "" { + segments = appendSegment(segments, interim) + } + return segments +} + +// appendSegment merges continuation segments to avoid duplicate transcript growth. +func appendSegment(segments []string, transcript string) []string { + transcript = cleanSegment(transcript) + if transcript == "" { + return segments + } + if len(segments) == 0 { + return append(segments, transcript) + } + + last := cleanSegment(segments[len(segments)-1]) + switch { + case transcript == last: + return segments + case strings.HasPrefix(transcript, last): + segments[len(segments)-1] = transcript + return segments + case strings.HasPrefix(last, transcript): + return segments + default: + return append(segments, transcript) + } +} + +// isInterimContinuation decides whether an interim update extends prior speech. +func isInterimContinuation(previous string, current string) bool { + previous = cleanSegment(previous) + current = cleanSegment(current) + if previous == "" || current == "" { + return true + } + if previous == current { + return true + } + if strings.HasPrefix(current, previous) || strings.HasPrefix(previous, current) { + return true + } + + prevWords := strings.Fields(previous) + currWords := strings.Fields(current) + common := commonPrefixWords(prevWords, currWords) + shorter := len(prevWords) + if len(currWords) < shorter { + shorter = len(currWords) + } + if shorter == 0 { + return true + } + return common*2 >= shorter +} + +// commonPrefixWords counts shared leading words across two slices. +func commonPrefixWords(left []string, right []string) int { + limit := len(left) + if len(right) < limit { + limit = len(right) + } + count := 0 + for i := 0; i < limit; i++ { + if left[i] != right[i] { + break + } + count++ + } + return count +} + +// cleanSegment normalizes transcript whitespace. +func cleanSegment(raw string) string { + raw = strings.TrimSpace(raw) + if raw == "" { + return "" + } + return strings.Join(strings.Fields(raw), " ") +} diff --git a/apps/sotto/internal/session/commit.go b/apps/sotto/internal/session/commit.go new file mode 100644 index 0000000..8ec1be5 --- /dev/null +++ b/apps/sotto/internal/session/commit.go @@ -0,0 +1,15 @@ +package session + +import "context" + +// Committer persists/dispatches a transcript when session stop succeeds. +type Committer interface { + Commit(context.Context, string) error +} + +// CommitFunc adapts a function to the Committer interface. +type CommitFunc func(context.Context, string) error + +func (f CommitFunc) Commit(ctx context.Context, transcript string) error { + return f(ctx, transcript) +} diff --git a/apps/sotto/internal/session/session.go b/apps/sotto/internal/session/session.go new file mode 100644 index 0000000..7265ef7 --- /dev/null +++ b/apps/sotto/internal/session/session.go @@ -0,0 +1,320 @@ +// Package session coordinates dictation lifecycle state, actions, and commit flow. +package session + +import ( + "context" + "errors" + "fmt" + "log/slog" + "strings" + "sync" + "time" + + "github.com/rbright/sotto/internal/fsm" + "github.com/rbright/sotto/internal/ipc" +) + +type action int + +const ( + actionStop action = iota + 1 + actionCancel +) + +// Result is the complete lifecycle output returned by one Run invocation. +type Result struct { + State fsm.State + Transcript string + Cancelled bool + Err error + AudioDevice string + BytesCaptured int64 + GRPCLatency time.Duration + StartedAt time.Time + FinishedAt time.Time + FocusedMonitor string +} + +// Indicator is the session-facing subset of indicator behavior. +type Indicator interface { + ShowRecording(context.Context) + ShowTranscribing(context.Context) + ShowError(context.Context, string) + CueStop(context.Context) + CueComplete(context.Context) + CueCancel(context.Context) + Hide(context.Context) + FocusedMonitor() string +} + +// noopIndicator preserves session flow when no indicator is wired. +type noopIndicator struct{} + +func (noopIndicator) ShowRecording(context.Context) {} +func (noopIndicator) ShowTranscribing(context.Context) {} +func (noopIndicator) ShowError(context.Context, string) {} +func (noopIndicator) CueStop(context.Context) {} +func (noopIndicator) CueComplete(context.Context) {} +func (noopIndicator) CueCancel(context.Context) {} +func (noopIndicator) Hide(context.Context) {} +func (noopIndicator) FocusedMonitor() string { return "" } + +// Controller orchestrates session state transitions and side effects. +type Controller struct { + logger *slog.Logger + transcribe Transcriber + commit Committer + indicator Indicator + + mu sync.RWMutex + state fsm.State + + actions chan action +} + +// NewController constructs a session controller with safe default fallbacks. +func NewController( + logger *slog.Logger, + transcriber Transcriber, + committer Committer, + indicator Indicator, +) *Controller { + if transcriber == nil { + transcriber = PlaceholderTranscriber{} + } + if committer == nil { + committer = CommitFunc(func(context.Context, string) error { return nil }) + } + if indicator == nil { + indicator = noopIndicator{} + } + + return &Controller{ + logger: logger, + transcribe: transcriber, + commit: committer, + indicator: indicator, + state: fsm.StateIdle, + actions: make(chan action, 1), + } +} + +// State returns the current FSM state snapshot. +func (c *Controller) State() fsm.State { + c.mu.RLock() + defer c.mu.RUnlock() + return c.state +} + +// transition applies one FSM event to the controller state. +func (c *Controller) transition(event fsm.Event) error { + c.mu.Lock() + defer c.mu.Unlock() + + next, err := fsm.Transition(c.state, event) + if err != nil { + return err + } + c.state = next + return nil +} + +// Run executes one owner lifecycle from start to stop/cancel/failure completion. +func (c *Controller) Run(ctx context.Context) Result { + result := Result{StartedAt: time.Now()} + + if err := c.transition(fsm.EventStart); err != nil { + result.State = c.State() + result.Err = err + result.FinishedAt = time.Now() + return result + } + + c.indicator.ShowRecording(ctx) + + if err := c.transcribe.Start(ctx); err != nil { + c.indicator.ShowError(ctx, "Unable to start recording") + c.toErrorAndReset() + result.State = c.State() + result.Err = err + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + } + + defer func() { + cleanupCtx, cancel := context.WithTimeout(context.Background(), 800*time.Millisecond) + defer cancel() + c.indicator.Hide(cleanupCtx) + }() + + select { + case <-ctx.Done(): + _ = c.transcribe.Cancel(context.Background()) + c.indicator.CueCancel(context.Background()) + c.indicator.ShowError(context.Background(), "Cancelled") + c.toErrorAndReset() + result.State = c.State() + result.Err = ctx.Err() + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + case a := <-c.actions: + switch a { + case actionCancel: + _ = c.transcribe.Cancel(context.Background()) + c.indicator.CueCancel(context.Background()) + _ = c.transition(fsm.EventCancel) + result.State = c.State() + result.Cancelled = true + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + case actionStop: + if err := c.transition(fsm.EventStop); err != nil { + c.toErrorAndReset() + result.State = c.State() + result.Err = err + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + } + c.indicator.ShowTranscribing(ctx) + + stopResult, err := c.transcribe.StopAndTranscribe(ctx) + c.indicator.CueStop(context.Background()) + if err != nil { + c.indicator.ShowError(context.Background(), "Speech recognition failed") + c.toErrorAndReset() + result.State = c.State() + result.Err = err + result.BytesCaptured = stopResult.BytesCaptured + result.AudioDevice = stopResult.AudioDevice + result.GRPCLatency = stopResult.GRPCLatency + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + } + + if strings.TrimSpace(stopResult.Transcript) == "" { + c.indicator.ShowError(context.Background(), "No speech detected") + c.toErrorAndReset() + result.State = c.State() + result.Err = ErrEmptyTranscript + result.Transcript = stopResult.Transcript + result.AudioDevice = stopResult.AudioDevice + result.BytesCaptured = stopResult.BytesCaptured + result.GRPCLatency = stopResult.GRPCLatency + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + } + + if err := c.commit.Commit(ctx, stopResult.Transcript); err != nil { + c.indicator.ShowError(context.Background(), "Output dispatch failed") + c.toErrorAndReset() + result.State = c.State() + result.Err = err + result.Transcript = stopResult.Transcript + result.AudioDevice = stopResult.AudioDevice + result.BytesCaptured = stopResult.BytesCaptured + result.GRPCLatency = stopResult.GRPCLatency + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + } + c.indicator.CueComplete(context.Background()) + + if err := c.transition(fsm.EventTranscribed); err != nil { + result.State = c.State() + result.Err = err + result.Transcript = stopResult.Transcript + result.AudioDevice = stopResult.AudioDevice + result.BytesCaptured = stopResult.BytesCaptured + result.GRPCLatency = stopResult.GRPCLatency + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + } + + result.State = c.State() + result.Transcript = stopResult.Transcript + result.AudioDevice = stopResult.AudioDevice + result.BytesCaptured = stopResult.BytesCaptured + result.GRPCLatency = stopResult.GRPCLatency + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + default: + c.toErrorAndReset() + result.State = c.State() + result.Err = fmt.Errorf("unknown action %d", a) + result.FinishedAt = time.Now() + result.FocusedMonitor = c.indicator.FocusedMonitor() + return result + } + } +} + +// Handle serves IPC commands for the active owner session. +func (c *Controller) Handle(_ context.Context, req ipc.Request) ipc.Response { + switch req.Command { + case "status": + return ipc.Response{OK: true, State: string(c.State()), Message: "status"} + case "toggle": + return c.requestStop("toggle") + case "stop": + return c.requestStop("stop") + case "cancel": + return c.requestCancel() + default: + return ipc.Response{OK: false, State: string(c.State()), Error: fmt.Sprintf("unknown command: %s", req.Command)} + } +} + +// requestStop enqueues a stop action when state permits it. +func (c *Controller) requestStop(source string) ipc.Response { + state := c.State() + if state == fsm.StateTranscribing { + return ipc.Response{OK: false, State: string(state), Error: "already transcribing"} + } + if state != fsm.StateRecording { + return ipc.Response{OK: false, State: string(state), Error: fmt.Sprintf("cannot %s from state %s", source, state)} + } + + select { + case c.actions <- actionStop: + return ipc.Response{OK: true, State: string(state), Message: "stop requested"} + default: + return ipc.Response{OK: true, State: string(state), Message: "stop already requested"} + } +} + +// requestCancel enqueues a cancel action when state permits it. +func (c *Controller) requestCancel() ipc.Response { + state := c.State() + if state == fsm.StateTranscribing { + return ipc.Response{OK: false, State: string(state), Error: "cannot cancel while transcribing"} + } + if state != fsm.StateRecording { + return ipc.Response{OK: false, State: string(state), Error: fmt.Sprintf("cannot cancel from state %s", state)} + } + + select { + case c.actions <- actionCancel: + return ipc.Response{OK: true, State: string(state), Message: "cancel requested"} + default: + return ipc.Response{OK: true, State: string(state), Message: "cancel already requested"} + } +} + +// toErrorAndReset transitions to error and back to idle best-effort. +func (c *Controller) toErrorAndReset() { + _ = c.transition(fsm.EventFail) + _ = c.transition(fsm.EventReset) +} + +// IsPipelineUnavailable reports whether an error represents missing pipeline wiring. +func IsPipelineUnavailable(err error) bool { + return errors.Is(err, ErrPipelineUnavailable) +} diff --git a/apps/sotto/internal/session/session_controller_test.go b/apps/sotto/internal/session/session_controller_test.go new file mode 100644 index 0000000..06c6244 --- /dev/null +++ b/apps/sotto/internal/session/session_controller_test.go @@ -0,0 +1,198 @@ +package session + +import ( + "context" + "errors" + "testing" + "time" + + "github.com/rbright/sotto/internal/fsm" + "github.com/rbright/sotto/internal/ipc" + "github.com/stretchr/testify/require" +) + +func TestHandleStatusAndUnknownCommand(t *testing.T) { + ctrl := NewController(nil, &fakeTranscriber{}, nil, &fakeIndicator{}) + + status := ctrl.Handle(context.Background(), ipc.Request{Command: "status"}) + require.True(t, status.OK) + require.Equal(t, string(fsm.StateIdle), status.State) + + unknown := ctrl.Handle(context.Background(), ipc.Request{Command: "definitely-unknown"}) + require.False(t, unknown.OK) + require.Contains(t, unknown.Error, "unknown command") +} + +func TestRequestStopAndCancelStateGuards(t *testing.T) { + ctrl := NewController(nil, &fakeTranscriber{}, nil, &fakeIndicator{}) + + stopFromIdle := ctrl.Handle(context.Background(), ipc.Request{Command: "stop"}) + require.False(t, stopFromIdle.OK) + require.Contains(t, stopFromIdle.Error, "cannot stop from state idle") + + cancelFromIdle := ctrl.Handle(context.Background(), ipc.Request{Command: "cancel"}) + require.False(t, cancelFromIdle.OK) + require.Contains(t, cancelFromIdle.Error, "cannot cancel from state idle") + + ctrl.mu.Lock() + ctrl.state = fsm.StateTranscribing + ctrl.mu.Unlock() + + stopFromTranscribing := ctrl.Handle(context.Background(), ipc.Request{Command: "stop"}) + require.False(t, stopFromTranscribing.OK) + require.Contains(t, stopFromTranscribing.Error, "already transcribing") + + cancelFromTranscribing := ctrl.Handle(context.Background(), ipc.Request{Command: "cancel"}) + require.False(t, cancelFromTranscribing.OK) + require.Contains(t, cancelFromTranscribing.Error, "cannot cancel while transcribing") +} + +func TestRequestStopAndCancelAlreadyRequested(t *testing.T) { + ctrl := NewController(nil, &fakeTranscriber{}, nil, &fakeIndicator{}) + + ctrl.mu.Lock() + ctrl.state = fsm.StateRecording + ctrl.mu.Unlock() + + ctrl.actions <- actionStop + stop := ctrl.requestStop("stop") + require.True(t, stop.OK) + require.Equal(t, "stop already requested", stop.Message) + + <-ctrl.actions + ctrl.actions <- actionCancel + cancel := ctrl.requestCancel() + require.True(t, cancel.OK) + require.Equal(t, "cancel already requested", cancel.Message) +} + +func TestRunStartFailure(t *testing.T) { + transcriber := &fakeTranscriber{startErr: errors.New("start failed")} + indicator := &fakeIndicator{} + ctrl := NewController(nil, transcriber, nil, indicator) + + result := ctrl.Run(context.Background()) + require.Error(t, result.Err) + require.Equal(t, fsm.StateIdle, result.State) + require.NotZero(t, result.FinishedAt) + require.Equal(t, int32(0), indicator.stopCues.Load()) + require.Equal(t, int32(0), indicator.completeCues.Load()) +} + +func TestRunCommitFailure(t *testing.T) { + indicator := &fakeIndicator{} + ctrl := NewController( + nil, + &fakeTranscriber{transcript: "hello world"}, + CommitFunc(func(context.Context, string) error { return errors.New("commit failed") }), + indicator, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + resp := ctrl.Handle(ctx, ipc.Request{Command: "stop"}) + require.True(t, resp.OK) + + result := <-resultCh + require.Error(t, result.Err) + require.Contains(t, result.Err.Error(), "commit failed") + require.Equal(t, int32(1), indicator.stopCues.Load()) + require.Equal(t, int32(0), indicator.completeCues.Load()) +} + +func TestRunContextCancelled(t *testing.T) { + indicator := &fakeIndicator{} + ctrl := NewController(nil, &fakeTranscriber{}, nil, indicator) + + ctx, cancel := context.WithCancel(context.Background()) + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + cancel() + + result := <-resultCh + require.ErrorIs(t, result.Err, context.Canceled) + require.Equal(t, fsm.StateIdle, result.State) + require.Equal(t, int32(1), indicator.cancelCues.Load()) + require.False(t, result.Cancelled) +} + +func TestRunUnknownAction(t *testing.T) { + ctrl := NewController(nil, &fakeTranscriber{}, nil, &fakeIndicator{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + ctrl.actions <- action(99) + + result := <-resultCh + require.Error(t, result.Err) + require.Contains(t, result.Err.Error(), "unknown action") + require.Equal(t, fsm.StateIdle, result.State) +} + +func TestIsPipelineUnavailable(t *testing.T) { + require.True(t, IsPipelineUnavailable(ErrPipelineUnavailable)) + require.False(t, IsPipelineUnavailable(errors.New("different error"))) + require.False(t, IsPipelineUnavailable(nil)) +} + +func TestPlaceholderTranscriberContract(t *testing.T) { + p := PlaceholderTranscriber{} + require.NoError(t, p.Start(context.Background())) + + result, err := p.StopAndTranscribe(context.Background()) + require.ErrorIs(t, err, ErrPipelineUnavailable) + require.Equal(t, StopResult{}, result) + + require.NoError(t, p.Cancel(context.Background())) +} + +func TestCommitFuncDelegates(t *testing.T) { + called := false + commit := CommitFunc(func(_ context.Context, transcript string) error { + called = true + require.Equal(t, "hello", transcript) + return nil + }) + + require.NoError(t, commit.Commit(context.Background(), "hello")) + require.True(t, called) +} + +func TestResultTimestampsAdvance(t *testing.T) { + ctrl := NewController(nil, &fakeTranscriber{transcript: "ok"}, nil, &fakeIndicator{}) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + require.True(t, ctrl.Handle(ctx, ipc.Request{Command: "stop"}).OK) + result := <-resultCh + + require.False(t, result.StartedAt.IsZero()) + require.False(t, result.FinishedAt.IsZero()) + require.True(t, result.FinishedAt.After(result.StartedAt) || result.FinishedAt.Equal(result.StartedAt)) + require.LessOrEqual(t, result.FinishedAt.Sub(result.StartedAt), 2*time.Second) +} diff --git a/apps/sotto/internal/session/session_test.go b/apps/sotto/internal/session/session_test.go new file mode 100644 index 0000000..f3fcda2 --- /dev/null +++ b/apps/sotto/internal/session/session_test.go @@ -0,0 +1,236 @@ +package session + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" + + "github.com/rbright/sotto/internal/fsm" + "github.com/rbright/sotto/internal/ipc" +) + +type fakeIndicator struct { + stopCues atomic.Int32 + completeCues atomic.Int32 + cancelCues atomic.Int32 +} + +func (*fakeIndicator) ShowRecording(context.Context) {} +func (*fakeIndicator) ShowTranscribing(context.Context) {} +func (*fakeIndicator) ShowError(context.Context, string) {} +func (f *fakeIndicator) CueStop(context.Context) { f.stopCues.Add(1) } +func (f *fakeIndicator) CueComplete(context.Context) { f.completeCues.Add(1) } +func (f *fakeIndicator) CueCancel(context.Context) { f.cancelCues.Add(1) } +func (*fakeIndicator) Hide(context.Context) {} +func (*fakeIndicator) FocusedMonitor() string { return "DP-1" } + +type fakeTranscriber struct { + startErr error + transcript string + stopErr error + cancelCalls atomic.Int32 +} + +func (f *fakeTranscriber) Start(context.Context) error { + return f.startErr +} + +func (f *fakeTranscriber) StopAndTranscribe(context.Context) (StopResult, error) { + return StopResult{ + Transcript: f.transcript, + AudioDevice: "test mic", + BytesCaptured: 3200, + GRPCLatency: 200 * time.Millisecond, + }, f.stopErr +} + +func (f *fakeTranscriber) Cancel(context.Context) error { + f.cancelCalls.Add(1) + return nil +} + +func TestControllerCancel(t *testing.T) { + transcriber := &fakeTranscriber{} + ind := &fakeIndicator{} + ctrl := NewController(nil, transcriber, nil, ind) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + resp := ctrl.Handle(ctx, ipc.Request{Command: "cancel"}) + if !resp.OK { + t.Fatalf("cancel response not OK: %+v", resp) + } + + result := <-resultCh + if !result.Cancelled { + t.Fatalf("expected cancelled result, got %+v", result) + } + if state := ctrl.State(); state != fsm.StateIdle { + t.Fatalf("expected idle state after cancel, got %s", state) + } + if transcriber.cancelCalls.Load() == 0 { + t.Fatalf("expected cancel to propagate to transcriber") + } + if ind.cancelCues.Load() == 0 { + t.Fatalf("expected cancel cue to play") + } + if ind.stopCues.Load() != 0 { + t.Fatalf("expected no stop cue on cancel") + } + if ind.completeCues.Load() != 0 { + t.Fatalf("expected no complete cue on cancel") + } +} + +func TestControllerStopCommitsTranscript(t *testing.T) { + var committed atomic.Bool + ind := &fakeIndicator{} + ctrl := NewController( + nil, + &fakeTranscriber{transcript: "hello world"}, + CommitFunc(func(context.Context, string) error { + committed.Store(true) + return nil + }), + ind, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + resp := ctrl.Handle(ctx, ipc.Request{Command: "stop"}) + if !resp.OK { + t.Fatalf("stop response not OK: %+v", resp) + } + + result := <-resultCh + if result.Err != nil { + t.Fatalf("unexpected result error: %v", result.Err) + } + if result.Transcript != "hello world" { + t.Fatalf("unexpected transcript: %q", result.Transcript) + } + if result.AudioDevice != "test mic" { + t.Fatalf("unexpected audio device: %q", result.AudioDevice) + } + if result.BytesCaptured != 3200 { + t.Fatalf("unexpected bytes captured: %d", result.BytesCaptured) + } + if !committed.Load() { + t.Fatalf("expected committer to run") + } + if ind.stopCues.Load() == 0 { + t.Fatalf("expected stop cue to play") + } + if ind.cancelCues.Load() != 0 { + t.Fatalf("expected no cancel cue on stop") + } + if ind.completeCues.Load() == 0 { + t.Fatalf("expected complete cue on successful commit") + } +} + +func TestControllerStopPipelineError(t *testing.T) { + ind := &fakeIndicator{} + ctrl := NewController(nil, &fakeTranscriber{stopErr: ErrPipelineUnavailable}, nil, ind) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + resp := ctrl.Handle(ctx, ipc.Request{Command: "toggle"}) + if !resp.OK { + t.Fatalf("toggle response not OK: %+v", resp) + } + + result := <-resultCh + if !errors.Is(result.Err, ErrPipelineUnavailable) { + t.Fatalf("unexpected result error: %v", result.Err) + } + if state := ctrl.State(); state != fsm.StateIdle { + t.Fatalf("expected idle after error reset, got %s", state) + } + if ind.stopCues.Load() == 0 { + t.Fatalf("expected stop cue even on pipeline error") + } + if ind.completeCues.Load() != 0 { + t.Fatalf("did not expect complete cue when stop fails") + } +} + +func TestControllerStopEmptyTranscriptReturnsError(t *testing.T) { + var committed atomic.Bool + ind := &fakeIndicator{} + ctrl := NewController( + nil, + &fakeTranscriber{transcript: ""}, + CommitFunc(func(context.Context, string) error { + committed.Store(true) + return nil + }), + ind, + ) + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + resultCh := make(chan Result, 1) + go func() { + resultCh <- ctrl.Run(ctx) + }() + + waitForState(t, ctrl, fsm.StateRecording) + resp := ctrl.Handle(ctx, ipc.Request{Command: "stop"}) + if !resp.OK { + t.Fatalf("stop response not OK: %+v", resp) + } + + result := <-resultCh + if !errors.Is(result.Err, ErrEmptyTranscript) { + t.Fatalf("unexpected result error: %v", result.Err) + } + if committed.Load() { + t.Fatalf("expected committer not to run for empty transcript") + } + if state := ctrl.State(); state != fsm.StateIdle { + t.Fatalf("expected idle after empty transcript error reset, got %s", state) + } + if ind.stopCues.Load() == 0 { + t.Fatalf("expected stop cue on empty transcript") + } + if ind.completeCues.Load() != 0 { + t.Fatalf("did not expect complete cue on empty transcript") + } +} + +func waitForState(t *testing.T, ctrl *Controller, desired fsm.State) { + t.Helper() + deadline := time.Now().Add(2 * time.Second) + for time.Now().Before(deadline) { + if ctrl.State() == desired { + return + } + time.Sleep(10 * time.Millisecond) + } + t.Fatalf("timed out waiting for state %s (current=%s)", desired, ctrl.State()) +} diff --git a/apps/sotto/internal/session/transcriber.go b/apps/sotto/internal/session/transcriber.go new file mode 100644 index 0000000..d90be38 --- /dev/null +++ b/apps/sotto/internal/session/transcriber.go @@ -0,0 +1,44 @@ +package session + +import ( + "context" + "errors" + "time" +) + +var ( + // ErrPipelineUnavailable indicates runtime transcriber wiring is missing. + ErrPipelineUnavailable = errors.New("audio capture and ASR pipeline not implemented") + // ErrEmptyTranscript indicates stop completed but no usable speech was recognized. + ErrEmptyTranscript = errors.New("no speech recognized; check microphone input or mute state") +) + +// StopResult is the transcriber output consumed by the session controller. +type StopResult struct { + Transcript string + AudioDevice string + BytesCaptured int64 + GRPCLatency time.Duration +} + +// Transcriber abstracts capture/ASR operations needed by session orchestration. +type Transcriber interface { + Start(context.Context) error + StopAndTranscribe(context.Context) (StopResult, error) + Cancel(context.Context) error +} + +// PlaceholderTranscriber is a no-op placeholder used in tests/fallback wiring. +type PlaceholderTranscriber struct{} + +func (PlaceholderTranscriber) Start(context.Context) error { + return nil +} + +func (PlaceholderTranscriber) StopAndTranscribe(context.Context) (StopResult, error) { + return StopResult{}, ErrPipelineUnavailable +} + +func (PlaceholderTranscriber) Cancel(context.Context) error { + return nil +} diff --git a/apps/sotto/internal/transcript/assemble.go b/apps/sotto/internal/transcript/assemble.go new file mode 100644 index 0000000..e426672 --- /dev/null +++ b/apps/sotto/internal/transcript/assemble.go @@ -0,0 +1,22 @@ +// Package transcript assembles and normalizes recognized ASR segments. +package transcript + +import "strings" + +// Assemble joins final ASR segments and applies whitespace/trailing-space normalization. +func Assemble(finalSegments []string, trailingSpace bool) string { + if len(finalSegments) == 0 { + return "" + } + + joined := strings.Join(finalSegments, " ") + normalized := strings.Join(strings.Fields(joined), " ") + if normalized == "" { + return "" + } + + if trailingSpace { + return normalized + " " + } + return normalized +} diff --git a/apps/sotto/internal/transcript/assemble_test.go b/apps/sotto/internal/transcript/assemble_test.go new file mode 100644 index 0000000..89dc704 --- /dev/null +++ b/apps/sotto/internal/transcript/assemble_test.go @@ -0,0 +1,42 @@ +package transcript + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAssembleNormalizesWhitespaceAndTrailingSpace(t *testing.T) { + t.Parallel() + + got := Assemble([]string{" hello", "world ", "\nfrom", "sotto"}, true) + require.Equal(t, "hello world from sotto ", got) +} + +func TestAssembleWithoutTrailingSpace(t *testing.T) { + t.Parallel() + + got := Assemble([]string{"hello", "world"}, false) + require.Equal(t, "hello world", got) +} + +func TestAssembleEmptyInput(t *testing.T) { + t.Parallel() + + require.Empty(t, Assemble(nil, true)) +} + +func TestAssembleSkipsWhitespaceOnlySegments(t *testing.T) { + t.Parallel() + + got := Assemble([]string{" ", "\n\t", "hello"}, false) + require.Equal(t, "hello", got) +} + +func TestAssembleIdempotentForNormalizedOutput(t *testing.T) { + t.Parallel() + + first := Assemble([]string{"hello", "world"}, false) + second := Assemble([]string{first}, false) + require.Equal(t, first, second) +} diff --git a/apps/sotto/internal/version/version.go b/apps/sotto/internal/version/version.go new file mode 100644 index 0000000..5992d58 --- /dev/null +++ b/apps/sotto/internal/version/version.go @@ -0,0 +1,15 @@ +// Package version exposes build metadata used by `sotto version`. +package version + +import "runtime" + +var ( + Version = "dev" + Commit = "none" + Date = "unknown" +) + +// String returns build metadata in the user-facing version output format. +func String() string { + return "sotto " + Version + " (commit=" + Commit + ", date=" + Date + ", go=" + runtime.Version() + ")" +} diff --git a/apps/sotto/internal/version/version_test.go b/apps/sotto/internal/version/version_test.go new file mode 100644 index 0000000..aa99dcc --- /dev/null +++ b/apps/sotto/internal/version/version_test.go @@ -0,0 +1,28 @@ +package version + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestStringIncludesBuildMetadata(t *testing.T) { + originalVersion := Version + originalCommit := Commit + originalDate := Date + t.Cleanup(func() { + Version = originalVersion + Commit = originalCommit + Date = originalDate + }) + + Version = "1.2.3" + Commit = "abc123" + Date = "2026-02-18" + + got := String() + require.Contains(t, got, "sotto 1.2.3") + require.Contains(t, got, "commit=abc123") + require.Contains(t, got, "date=2026-02-18") + require.Contains(t, got, "go=") +} diff --git a/apps/sotto/proto/gen/go/riva/proto/riva_asr.pb.go b/apps/sotto/proto/gen/go/riva/proto/riva_asr.pb.go new file mode 100644 index 0000000..fd9eb44 --- /dev/null +++ b/apps/sotto/proto/gen/go/riva/proto/riva_asr.pb.go @@ -0,0 +1,1606 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: MIT + +// Copyright 2019 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.6 +// protoc (unknown) +// source: riva/proto/riva_asr.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type RivaSpeechRecognitionConfigRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // If model is specified only return config for model, otherwise return all + // configs. + ModelName string `protobuf:"bytes,1,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RivaSpeechRecognitionConfigRequest) Reset() { + *x = RivaSpeechRecognitionConfigRequest{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RivaSpeechRecognitionConfigRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RivaSpeechRecognitionConfigRequest) ProtoMessage() {} + +func (x *RivaSpeechRecognitionConfigRequest) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RivaSpeechRecognitionConfigRequest.ProtoReflect.Descriptor instead. +func (*RivaSpeechRecognitionConfigRequest) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{0} +} + +func (x *RivaSpeechRecognitionConfigRequest) GetModelName() string { + if x != nil { + return x.ModelName + } + return "" +} + +type RivaSpeechRecognitionConfigResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + ModelConfig []*RivaSpeechRecognitionConfigResponse_Config `protobuf:"bytes,1,rep,name=model_config,json=modelConfig,proto3" json:"model_config,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RivaSpeechRecognitionConfigResponse) Reset() { + *x = RivaSpeechRecognitionConfigResponse{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RivaSpeechRecognitionConfigResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RivaSpeechRecognitionConfigResponse) ProtoMessage() {} + +func (x *RivaSpeechRecognitionConfigResponse) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RivaSpeechRecognitionConfigResponse.ProtoReflect.Descriptor instead. +func (*RivaSpeechRecognitionConfigResponse) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{1} +} + +func (x *RivaSpeechRecognitionConfigResponse) GetModelConfig() []*RivaSpeechRecognitionConfigResponse_Config { + if x != nil { + return x.ModelConfig + } + return nil +} + +// RecognizeRequest is used for batch processing of a single audio recording. +type RecognizeRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Provides information to recognizer that specifies how to process the + // request. + Config *RecognitionConfig `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` + // The raw audio data to be processed. The audio bytes must be encoded as + // specified in `RecognitionConfig`. + Audio []byte `protobuf:"bytes,2,opt,name=audio,proto3" json:"audio,omitempty"` + // The ID to be associated with the request. If provided, this will be + // returned in the corresponding response. + Id *RequestId `protobuf:"bytes,100,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RecognizeRequest) Reset() { + *x = RecognizeRequest{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RecognizeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RecognizeRequest) ProtoMessage() {} + +func (x *RecognizeRequest) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RecognizeRequest.ProtoReflect.Descriptor instead. +func (*RecognizeRequest) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{2} +} + +func (x *RecognizeRequest) GetConfig() *RecognitionConfig { + if x != nil { + return x.Config + } + return nil +} + +func (x *RecognizeRequest) GetAudio() []byte { + if x != nil { + return x.Audio + } + return nil +} + +func (x *RecognizeRequest) GetId() *RequestId { + if x != nil { + return x.Id + } + return nil +} + +// A StreamingRecognizeRequest is used to configure and stream audio content to +// the Riva ASR Service. The first message sent must include only a +// StreamingRecognitionConfig. Subsequent messages sent in the stream must +// contain only raw bytes of the audio to be recognized. +type StreamingRecognizeRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The streaming request, which is either a streaming config or audio content. + // + // Types that are valid to be assigned to StreamingRequest: + // + // *StreamingRecognizeRequest_StreamingConfig + // *StreamingRecognizeRequest_AudioContent + StreamingRequest isStreamingRecognizeRequest_StreamingRequest `protobuf_oneof:"streaming_request"` + // The ID to be associated with the request. If provided, this will be + // returned in the corresponding responses. + Id *RequestId `protobuf:"bytes,100,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamingRecognizeRequest) Reset() { + *x = StreamingRecognizeRequest{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamingRecognizeRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamingRecognizeRequest) ProtoMessage() {} + +func (x *StreamingRecognizeRequest) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamingRecognizeRequest.ProtoReflect.Descriptor instead. +func (*StreamingRecognizeRequest) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{3} +} + +func (x *StreamingRecognizeRequest) GetStreamingRequest() isStreamingRecognizeRequest_StreamingRequest { + if x != nil { + return x.StreamingRequest + } + return nil +} + +func (x *StreamingRecognizeRequest) GetStreamingConfig() *StreamingRecognitionConfig { + if x != nil { + if x, ok := x.StreamingRequest.(*StreamingRecognizeRequest_StreamingConfig); ok { + return x.StreamingConfig + } + } + return nil +} + +func (x *StreamingRecognizeRequest) GetAudioContent() []byte { + if x != nil { + if x, ok := x.StreamingRequest.(*StreamingRecognizeRequest_AudioContent); ok { + return x.AudioContent + } + } + return nil +} + +func (x *StreamingRecognizeRequest) GetId() *RequestId { + if x != nil { + return x.Id + } + return nil +} + +type isStreamingRecognizeRequest_StreamingRequest interface { + isStreamingRecognizeRequest_StreamingRequest() +} + +type StreamingRecognizeRequest_StreamingConfig struct { + // Provides information to the recognizer that specifies how to process the + // request. The first `StreamingRecognizeRequest` message must contain a + // `streaming_config` message. + StreamingConfig *StreamingRecognitionConfig `protobuf:"bytes,1,opt,name=streaming_config,json=streamingConfig,proto3,oneof"` +} + +type StreamingRecognizeRequest_AudioContent struct { + // The audio data to be recognized. Sequential chunks of audio data are sent + // in sequential `StreamingRecognizeRequest` messages. The first + // `StreamingRecognizeRequest` message must not contain `audio` data + // and all subsequent `StreamingRecognizeRequest` messages must contain + // `audio` data. The audio bytes must be encoded as specified in + // `RecognitionConfig`. + AudioContent []byte `protobuf:"bytes,2,opt,name=audio_content,json=audioContent,proto3,oneof"` +} + +func (*StreamingRecognizeRequest_StreamingConfig) isStreamingRecognizeRequest_StreamingRequest() {} + +func (*StreamingRecognizeRequest_AudioContent) isStreamingRecognizeRequest_StreamingRequest() {} + +// EndpointingConfig is used for configuring different fields related to start +// or end of utterance +type EndpointingConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + // `start_history` is the size of the window, in milliseconds, used to + // detect start of utterance. + // `start_threshold` is the percentage threshold used to detect start of + // utterance. (0.0 to 1.0) + // If `start_threshold` of `start_history` ms of the acoustic model output + // have non-blank tokens, start of utterance is detected. + StartHistory *int32 `protobuf:"varint,1,opt,name=start_history,json=startHistory,proto3,oneof" json:"start_history,omitempty"` + StartThreshold *float32 `protobuf:"fixed32,2,opt,name=start_threshold,json=startThreshold,proto3,oneof" json:"start_threshold,omitempty"` + // `stop_history` is the size of the window, in milliseconds, used to + // detect end of utterance. + // `stop_threshold` is the percentage threshold used to detect end of + // utterance. (0.0 to 1.0) + // If `stop_threshold` of `stop_history` ms of the acoustic model output have + // non-blank tokens, end of utterance is detected and decoder will be reset. + StopHistory *int32 `protobuf:"varint,3,opt,name=stop_history,json=stopHistory,proto3,oneof" json:"stop_history,omitempty"` + StopThreshold *float32 `protobuf:"fixed32,4,opt,name=stop_threshold,json=stopThreshold,proto3,oneof" json:"stop_threshold,omitempty"` + // `stop_history_eou` and `stop_threshold_eou` are used for 2-pass end of utterance. + // `stop_history_eou` is the size of the window, in milliseconds, used to + // trigger 1st pass of end of utterance and generate a partial transcript + // with stability of 1. (stop_history_eou < stop_history) + // `stop_threshold_eou` is the percentage threshold used to trigger 1st + // pass of end of utterance. (0.0 to 1.0) + // If `stop_threshold_eou` of `stop_history_eou` ms of the acoustic model + // output have non-blank tokens, 1st pass of end of utterance is triggered. + StopHistoryEou *int32 `protobuf:"varint,5,opt,name=stop_history_eou,json=stopHistoryEou,proto3,oneof" json:"stop_history_eou,omitempty"` + StopThresholdEou *float32 `protobuf:"fixed32,6,opt,name=stop_threshold_eou,json=stopThresholdEou,proto3,oneof" json:"stop_threshold_eou,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *EndpointingConfig) Reset() { + *x = EndpointingConfig{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *EndpointingConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*EndpointingConfig) ProtoMessage() {} + +func (x *EndpointingConfig) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use EndpointingConfig.ProtoReflect.Descriptor instead. +func (*EndpointingConfig) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{4} +} + +func (x *EndpointingConfig) GetStartHistory() int32 { + if x != nil && x.StartHistory != nil { + return *x.StartHistory + } + return 0 +} + +func (x *EndpointingConfig) GetStartThreshold() float32 { + if x != nil && x.StartThreshold != nil { + return *x.StartThreshold + } + return 0 +} + +func (x *EndpointingConfig) GetStopHistory() int32 { + if x != nil && x.StopHistory != nil { + return *x.StopHistory + } + return 0 +} + +func (x *EndpointingConfig) GetStopThreshold() float32 { + if x != nil && x.StopThreshold != nil { + return *x.StopThreshold + } + return 0 +} + +func (x *EndpointingConfig) GetStopHistoryEou() int32 { + if x != nil && x.StopHistoryEou != nil { + return *x.StopHistoryEou + } + return 0 +} + +func (x *EndpointingConfig) GetStopThresholdEou() float32 { + if x != nil && x.StopThresholdEou != nil { + return *x.StopThresholdEou + } + return 0 +} + +// Provides information to the recognizer that specifies how to process the +// request +type RecognitionConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The encoding of the audio data sent in the request. + // + // All encodings support only 1 channel (mono) audio. + Encoding AudioEncoding `protobuf:"varint,1,opt,name=encoding,proto3,enum=nvidia.riva.AudioEncoding" json:"encoding,omitempty"` + // The sample rate in hertz (Hz) of the audio data sent in the + // + // `RecognizeRequest` or `StreamingRecognizeRequest` messages. + // + // The Riva server will automatically down-sample/up-sample the audio to + // match the ASR acoustic model sample rate. The sample rate value below 8kHz + // will not produce any meaningful output. + SampleRateHertz int32 `protobuf:"varint,2,opt,name=sample_rate_hertz,json=sampleRateHertz,proto3" json:"sample_rate_hertz,omitempty"` + // Required. The language of the supplied audio as a + // [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag. + // Example: "en-US". + LanguageCode string `protobuf:"bytes,3,opt,name=language_code,json=languageCode,proto3" json:"language_code,omitempty"` + // Maximum number of recognition hypotheses to be returned. + // Specifically, the maximum number of `SpeechRecognizeAlternative` messages + // within each `SpeechRecognizeResult`. + // The server may return fewer than `max_alternatives`. + // If omitted, will return a maximum of one. + MaxAlternatives int32 `protobuf:"varint,4,opt,name=max_alternatives,json=maxAlternatives,proto3" json:"max_alternatives,omitempty"` + // A custom field that enables profanity filtering for the generated + // transcripts. If set to 'true', the server filters out profanities, + // replacing all but the initial character in each filtered word with + // asterisks. For example, "x**". If set to `false` or omitted, profanities + // will not be filtered out. The default is `false`. + ProfanityFilter bool `protobuf:"varint,5,opt,name=profanity_filter,json=profanityFilter,proto3" json:"profanity_filter,omitempty"` + // Array of SpeechContext. + // A means to provide context to assist the speech recognition. For more + // information, see SpeechContext section + SpeechContexts []*SpeechContext `protobuf:"bytes,6,rep,name=speech_contexts,json=speechContexts,proto3" json:"speech_contexts,omitempty"` + // The number of channels in the input audio data. + // If `0` or omitted, defaults to one channel (mono). + // Note: Only single channel audio input is supported as of now. + AudioChannelCount int32 `protobuf:"varint,7,opt,name=audio_channel_count,json=audioChannelCount,proto3" json:"audio_channel_count,omitempty"` + // If `true`, the top result includes a list of words and the start and end + // time offsets (timestamps), and confidence scores for those words. If + // `false`, no word-level time offset information is returned. The default + // is `false`. + EnableWordTimeOffsets bool `protobuf:"varint,8,opt,name=enable_word_time_offsets,json=enableWordTimeOffsets,proto3" json:"enable_word_time_offsets,omitempty"` + // If 'true', adds punctuation to recognition result hypotheses. The + // default 'false' value does not add punctuation to result hypotheses. + EnableAutomaticPunctuation bool `protobuf:"varint,11,opt,name=enable_automatic_punctuation,json=enableAutomaticPunctuation,proto3" json:"enable_automatic_punctuation,omitempty"` + // This needs to be set to `true` explicitly and `audio_channel_count` > 1 + // to get each channel recognized separately. The recognition result will + // contain a `channel_tag` field to state which channel that result belongs + // to. If this is not true, we will only recognize the first channel. The + // request is billed cumulatively for all channels recognized: + // `audio_channel_count` multiplied by the length of the audio. + // Note: This field is not yet supported. + EnableSeparateRecognitionPerChannel bool `protobuf:"varint,12,opt,name=enable_separate_recognition_per_channel,json=enableSeparateRecognitionPerChannel,proto3" json:"enable_separate_recognition_per_channel,omitempty"` + // Which model to select for the given request. + // If empty, Riva will select the right model based on the other + // RecognitionConfig parameters. The model should correspond to the name + // passed to `riva-build` with the `--name` argument + Model string `protobuf:"bytes,13,opt,name=model,proto3" json:"model,omitempty"` + // The verbatim_transcripts flag enables or disable inverse text + // normalization. 'true' returns exactly what was said, with no + // denormalization. 'false' applies inverse text normalization, also this is + // the default + VerbatimTranscripts bool `protobuf:"varint,14,opt,name=verbatim_transcripts,json=verbatimTranscripts,proto3" json:"verbatim_transcripts,omitempty"` + // Config to enable speaker diarization and set additional + // parameters. For non-streaming requests, the diarization results will be + // provided only in the top alternative of the FINAL SpeechRecognitionResult. + DiarizationConfig *SpeakerDiarizationConfig `protobuf:"bytes,19,opt,name=diarization_config,json=diarizationConfig,proto3" json:"diarization_config,omitempty"` + // Custom fields for passing request-level + // configuration options to plugins used in the + // model pipeline. + CustomConfiguration map[string]string `protobuf:"bytes,24,rep,name=custom_configuration,json=customConfiguration,proto3" json:"custom_configuration,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + // Config for tuning start or end of utterance parameters. + // If empty, Riva will use default values or custom values if specified in riva-build arguments. + EndpointingConfig *EndpointingConfig `protobuf:"bytes,25,opt,name=endpointing_config,json=endpointingConfig,proto3,oneof" json:"endpointing_config,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RecognitionConfig) Reset() { + *x = RecognitionConfig{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RecognitionConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RecognitionConfig) ProtoMessage() {} + +func (x *RecognitionConfig) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RecognitionConfig.ProtoReflect.Descriptor instead. +func (*RecognitionConfig) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{5} +} + +func (x *RecognitionConfig) GetEncoding() AudioEncoding { + if x != nil { + return x.Encoding + } + return AudioEncoding_ENCODING_UNSPECIFIED +} + +func (x *RecognitionConfig) GetSampleRateHertz() int32 { + if x != nil { + return x.SampleRateHertz + } + return 0 +} + +func (x *RecognitionConfig) GetLanguageCode() string { + if x != nil { + return x.LanguageCode + } + return "" +} + +func (x *RecognitionConfig) GetMaxAlternatives() int32 { + if x != nil { + return x.MaxAlternatives + } + return 0 +} + +func (x *RecognitionConfig) GetProfanityFilter() bool { + if x != nil { + return x.ProfanityFilter + } + return false +} + +func (x *RecognitionConfig) GetSpeechContexts() []*SpeechContext { + if x != nil { + return x.SpeechContexts + } + return nil +} + +func (x *RecognitionConfig) GetAudioChannelCount() int32 { + if x != nil { + return x.AudioChannelCount + } + return 0 +} + +func (x *RecognitionConfig) GetEnableWordTimeOffsets() bool { + if x != nil { + return x.EnableWordTimeOffsets + } + return false +} + +func (x *RecognitionConfig) GetEnableAutomaticPunctuation() bool { + if x != nil { + return x.EnableAutomaticPunctuation + } + return false +} + +func (x *RecognitionConfig) GetEnableSeparateRecognitionPerChannel() bool { + if x != nil { + return x.EnableSeparateRecognitionPerChannel + } + return false +} + +func (x *RecognitionConfig) GetModel() string { + if x != nil { + return x.Model + } + return "" +} + +func (x *RecognitionConfig) GetVerbatimTranscripts() bool { + if x != nil { + return x.VerbatimTranscripts + } + return false +} + +func (x *RecognitionConfig) GetDiarizationConfig() *SpeakerDiarizationConfig { + if x != nil { + return x.DiarizationConfig + } + return nil +} + +func (x *RecognitionConfig) GetCustomConfiguration() map[string]string { + if x != nil { + return x.CustomConfiguration + } + return nil +} + +func (x *RecognitionConfig) GetEndpointingConfig() *EndpointingConfig { + if x != nil { + return x.EndpointingConfig + } + return nil +} + +// Provides information to the recognizer that specifies how to process the +// request +type StreamingRecognitionConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Provides information to the recognizer that specifies how to process the + // request + Config *RecognitionConfig `protobuf:"bytes,1,opt,name=config,proto3" json:"config,omitempty"` + // If `true`, interim results (tentative hypotheses) may be + // returned as they become available (these interim results are indicated with + // the `is_final=false` flag). + // If `false` or omitted, only `is_final=true` result(s) are returned. + InterimResults bool `protobuf:"varint,2,opt,name=interim_results,json=interimResults,proto3" json:"interim_results,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamingRecognitionConfig) Reset() { + *x = StreamingRecognitionConfig{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamingRecognitionConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamingRecognitionConfig) ProtoMessage() {} + +func (x *StreamingRecognitionConfig) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamingRecognitionConfig.ProtoReflect.Descriptor instead. +func (*StreamingRecognitionConfig) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{6} +} + +func (x *StreamingRecognitionConfig) GetConfig() *RecognitionConfig { + if x != nil { + return x.Config + } + return nil +} + +func (x *StreamingRecognitionConfig) GetInterimResults() bool { + if x != nil { + return x.InterimResults + } + return false +} + +// Config to enable speaker diarization. +type SpeakerDiarizationConfig struct { + state protoimpl.MessageState `protogen:"open.v1"` + // If 'true', enables speaker detection for each recognized word in + // the top alternative of the recognition result using a speaker_tag provided + // in the WordInfo. + EnableSpeakerDiarization bool `protobuf:"varint,1,opt,name=enable_speaker_diarization,json=enableSpeakerDiarization,proto3" json:"enable_speaker_diarization,omitempty"` + // Maximum number of speakers in the conversation. This gives flexibility by + // allowing the system to automatically determine the correct number of + // speakers. If not set, the default value is 8. + MaxSpeakerCount int32 `protobuf:"varint,2,opt,name=max_speaker_count,json=maxSpeakerCount,proto3" json:"max_speaker_count,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SpeakerDiarizationConfig) Reset() { + *x = SpeakerDiarizationConfig{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SpeakerDiarizationConfig) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SpeakerDiarizationConfig) ProtoMessage() {} + +func (x *SpeakerDiarizationConfig) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SpeakerDiarizationConfig.ProtoReflect.Descriptor instead. +func (*SpeakerDiarizationConfig) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{7} +} + +func (x *SpeakerDiarizationConfig) GetEnableSpeakerDiarization() bool { + if x != nil { + return x.EnableSpeakerDiarization + } + return false +} + +func (x *SpeakerDiarizationConfig) GetMaxSpeakerCount() int32 { + if x != nil { + return x.MaxSpeakerCount + } + return 0 +} + +// Provides "hints" to the speech recognizer to favor specific words and phrases +// in the results. +type SpeechContext struct { + state protoimpl.MessageState `protogen:"open.v1"` + // A list of strings containing words and phrases "hints" so that + // the speech recognition is more likely to recognize them. This can be used + // to improve the accuracy for specific words and phrases, for example, if + // specific commands are typically spoken by the user. This can also be used + // to add additional words to the vocabulary of the recognizer. + Phrases []string `protobuf:"bytes,1,rep,name=phrases,proto3" json:"phrases,omitempty"` + // Hint Boost. Positive value will increase the probability that a specific + // phrase will be recognized over other similar sounding phrases. The higher + // the boost, the higher the chance of false positive recognition as well. + // Though `boost` can accept a wide range of positive values, most use cases + // are best served with values between 0 and 20. We recommend using a binary + // search approach to finding the optimal value for your use case. + Boost float32 `protobuf:"fixed32,4,opt,name=boost,proto3" json:"boost,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SpeechContext) Reset() { + *x = SpeechContext{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SpeechContext) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SpeechContext) ProtoMessage() {} + +func (x *SpeechContext) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SpeechContext.ProtoReflect.Descriptor instead. +func (*SpeechContext) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{8} +} + +func (x *SpeechContext) GetPhrases() []string { + if x != nil { + return x.Phrases + } + return nil +} + +func (x *SpeechContext) GetBoost() float32 { + if x != nil { + return x.Boost + } + return 0 +} + +// The only message returned to the client by the `Recognize` method. It +// contains the result as zero or more sequential `SpeechRecognitionResult` +// messages. +type RecognizeResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Sequential list of transcription results corresponding to + // sequential portions of audio. Currently only returns one transcript. + Results []*SpeechRecognitionResult `protobuf:"bytes,1,rep,name=results,proto3" json:"results,omitempty"` + // The ID associated with the request + Id *RequestId `protobuf:"bytes,100,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RecognizeResponse) Reset() { + *x = RecognizeResponse{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RecognizeResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RecognizeResponse) ProtoMessage() {} + +func (x *RecognizeResponse) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RecognizeResponse.ProtoReflect.Descriptor instead. +func (*RecognizeResponse) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{9} +} + +func (x *RecognizeResponse) GetResults() []*SpeechRecognitionResult { + if x != nil { + return x.Results + } + return nil +} + +func (x *RecognizeResponse) GetId() *RequestId { + if x != nil { + return x.Id + } + return nil +} + +// A speech recognition result corresponding to the latest transcript +type SpeechRecognitionResult struct { + state protoimpl.MessageState `protogen:"open.v1"` + // May contain one or more recognition hypotheses (up to the + // maximum specified in `max_alternatives`). + // These alternatives are ordered in terms of accuracy, with the top (first) + // alternative being the most probable, as ranked by the recognizer. + Alternatives []*SpeechRecognitionAlternative `protobuf:"bytes,1,rep,name=alternatives,proto3" json:"alternatives,omitempty"` + // For multi-channel audio, this is the channel number corresponding to the + // recognized result for the audio from that channel. + // For audio_channel_count = N, its output values can range from '1' to 'N'. + ChannelTag int32 `protobuf:"varint,2,opt,name=channel_tag,json=channelTag,proto3" json:"channel_tag,omitempty"` + // Length of audio processed so far in seconds + AudioProcessed float32 `protobuf:"fixed32,3,opt,name=audio_processed,json=audioProcessed,proto3" json:"audio_processed,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SpeechRecognitionResult) Reset() { + *x = SpeechRecognitionResult{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SpeechRecognitionResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SpeechRecognitionResult) ProtoMessage() {} + +func (x *SpeechRecognitionResult) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[10] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SpeechRecognitionResult.ProtoReflect.Descriptor instead. +func (*SpeechRecognitionResult) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{10} +} + +func (x *SpeechRecognitionResult) GetAlternatives() []*SpeechRecognitionAlternative { + if x != nil { + return x.Alternatives + } + return nil +} + +func (x *SpeechRecognitionResult) GetChannelTag() int32 { + if x != nil { + return x.ChannelTag + } + return 0 +} + +func (x *SpeechRecognitionResult) GetAudioProcessed() float32 { + if x != nil { + return x.AudioProcessed + } + return 0 +} + +// Alternative hypotheses (a.k.a. n-best list). +type SpeechRecognitionAlternative struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Transcript text representing the words that the user spoke. + Transcript string `protobuf:"bytes,1,opt,name=transcript,proto3" json:"transcript,omitempty"` + // The confidence estimate. A higher number indicates an estimated greater + // likelihood that the recognized word is correct. This field is set only for + // a non-streaming result or, for a streaming result where is_final=true. + // This field is not guaranteed to be accurate and users should not rely on + // it to be always provided. Although confidence can currently be roughly + // interpreted as a natural-log probability, the estimate computation varies + // with difference configurations, and is subject to change. The default of + // 0.0 is a sentinel value indicating confidence was not set. + Confidence float32 `protobuf:"fixed32,2,opt,name=confidence,proto3" json:"confidence,omitempty"` + // A list of word-specific information for each recognized word. Only + // populated if is_final=true + Words []*WordInfo `protobuf:"bytes,3,rep,name=words,proto3" json:"words,omitempty"` + // List of language codes detected in the transcript. + LanguageCode []string `protobuf:"bytes,4,rep,name=language_code,json=languageCode,proto3" json:"language_code,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SpeechRecognitionAlternative) Reset() { + *x = SpeechRecognitionAlternative{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SpeechRecognitionAlternative) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SpeechRecognitionAlternative) ProtoMessage() {} + +func (x *SpeechRecognitionAlternative) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[11] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SpeechRecognitionAlternative.ProtoReflect.Descriptor instead. +func (*SpeechRecognitionAlternative) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{11} +} + +func (x *SpeechRecognitionAlternative) GetTranscript() string { + if x != nil { + return x.Transcript + } + return "" +} + +func (x *SpeechRecognitionAlternative) GetConfidence() float32 { + if x != nil { + return x.Confidence + } + return 0 +} + +func (x *SpeechRecognitionAlternative) GetWords() []*WordInfo { + if x != nil { + return x.Words + } + return nil +} + +func (x *SpeechRecognitionAlternative) GetLanguageCode() []string { + if x != nil { + return x.LanguageCode + } + return nil +} + +// Word-specific information for recognized words. +type WordInfo struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Time offset relative to the beginning of the audio in ms + // and corresponding to the start of the spoken word. + // This field is only set if `enable_word_time_offsets=true` and only + // in the top hypothesis. + StartTime int32 `protobuf:"varint,1,opt,name=start_time,json=startTime,proto3" json:"start_time,omitempty"` + // Time offset relative to the beginning of the audio in ms + // and corresponding to the end of the spoken word. + // This field is only set if `enable_word_time_offsets=true` and only + // in the top hypothesis. + EndTime int32 `protobuf:"varint,2,opt,name=end_time,json=endTime,proto3" json:"end_time,omitempty"` + // The word corresponding to this set of information. + Word string `protobuf:"bytes,3,opt,name=word,proto3" json:"word,omitempty"` + // The confidence estimate. A higher number indicates an estimated greater + // likelihood that the recognized word is correct. This field is set only for + // a non-streaming result or, for a streaming result where is_final=true. + // This field is not guaranteed to be accurate and users should not rely on + // it to be always provided. Although confidence can currently be roughly + // interpreted as a natural-log probability, the estimate computation varies + // with difference configurations, and is subject to change. The default of + // 0.0 is a sentinel value indicating confidence was not set. + Confidence float32 `protobuf:"fixed32,4,opt,name=confidence,proto3" json:"confidence,omitempty"` + // Output only. A distinct integer value is assigned for every speaker within + // the audio. This field specifies which one of those speakers was detected to + // have spoken this word. Value ranges from '1' to diarization_speaker_count. + // speaker_tag is set if enable_speaker_diarization = 'true' and only in the + // top alternative. + SpeakerTag int32 `protobuf:"varint,5,opt,name=speaker_tag,json=speakerTag,proto3" json:"speaker_tag,omitempty"` + // The language code of the word. + LanguageCode string `protobuf:"bytes,6,opt,name=language_code,json=languageCode,proto3" json:"language_code,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *WordInfo) Reset() { + *x = WordInfo{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WordInfo) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WordInfo) ProtoMessage() {} + +func (x *WordInfo) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[12] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WordInfo.ProtoReflect.Descriptor instead. +func (*WordInfo) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{12} +} + +func (x *WordInfo) GetStartTime() int32 { + if x != nil { + return x.StartTime + } + return 0 +} + +func (x *WordInfo) GetEndTime() int32 { + if x != nil { + return x.EndTime + } + return 0 +} + +func (x *WordInfo) GetWord() string { + if x != nil { + return x.Word + } + return "" +} + +func (x *WordInfo) GetConfidence() float32 { + if x != nil { + return x.Confidence + } + return 0 +} + +func (x *WordInfo) GetSpeakerTag() int32 { + if x != nil { + return x.SpeakerTag + } + return 0 +} + +func (x *WordInfo) GetLanguageCode() string { + if x != nil { + return x.LanguageCode + } + return "" +} + +type StreamingRecognizeResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // This repeated list contains the latest transcript(s) corresponding to + // audio currently being processed. + // Currently one result is returned, where each result can have multiple + // alternatives + Results []*StreamingRecognitionResult `protobuf:"bytes,1,rep,name=results,proto3" json:"results,omitempty"` + // The ID associated with the request + Id *RequestId `protobuf:"bytes,100,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamingRecognizeResponse) Reset() { + *x = StreamingRecognizeResponse{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamingRecognizeResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamingRecognizeResponse) ProtoMessage() {} + +func (x *StreamingRecognizeResponse) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[13] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamingRecognizeResponse.ProtoReflect.Descriptor instead. +func (*StreamingRecognizeResponse) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{13} +} + +func (x *StreamingRecognizeResponse) GetResults() []*StreamingRecognitionResult { + if x != nil { + return x.Results + } + return nil +} + +func (x *StreamingRecognizeResponse) GetId() *RequestId { + if x != nil { + return x.Id + } + return nil +} + +type PipelineStates struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Neural VAD probabilities + VadProbabilities []float32 `protobuf:"fixed32,1,rep,packed,name=vad_probabilities,json=vadProbabilities,proto3" json:"vad_probabilities,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *PipelineStates) Reset() { + *x = PipelineStates{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *PipelineStates) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*PipelineStates) ProtoMessage() {} + +func (x *PipelineStates) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[14] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use PipelineStates.ProtoReflect.Descriptor instead. +func (*PipelineStates) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{14} +} + +func (x *PipelineStates) GetVadProbabilities() []float32 { + if x != nil { + return x.VadProbabilities + } + return nil +} + +// A streaming speech recognition result corresponding to a portion of the audio +// that is currently being processed. +type StreamingRecognitionResult struct { + state protoimpl.MessageState `protogen:"open.v1"` + // May contain one or more recognition hypotheses (up to the + // maximum specified in `max_alternatives`). + // These alternatives are ordered in terms of accuracy, with the top (first) + // alternative being the most probable, as ranked by the recognizer. + Alternatives []*SpeechRecognitionAlternative `protobuf:"bytes,1,rep,name=alternatives,proto3" json:"alternatives,omitempty"` + // If `false`, this `StreamingRecognitionResult` represents an + // interim result that may change. If `true`, this is the final time the + // speech service will return this particular `StreamingRecognitionResult`, + // the recognizer will not return any further hypotheses for this portion of + // the transcript and corresponding audio. + IsFinal bool `protobuf:"varint,2,opt,name=is_final,json=isFinal,proto3" json:"is_final,omitempty"` + // An estimate of the likelihood that the recognizer will not + // change its guess about this interim result. Values range from 0.0 + // (completely unstable) to 1.0 (completely stable). + // This field is only provided for interim results (`is_final=false`). + // The default of 0.0 is a sentinel value indicating `stability` was not set. + Stability float32 `protobuf:"fixed32,3,opt,name=stability,proto3" json:"stability,omitempty"` + // For multi-channel audio, this is the channel number corresponding to the + // recognized result for the audio from that channel. + // For audio_channel_count = N, its output values can range from '1' to 'N'. + ChannelTag int32 `protobuf:"varint,5,opt,name=channel_tag,json=channelTag,proto3" json:"channel_tag,omitempty"` + // Length of audio processed so far in seconds + AudioProcessed float32 `protobuf:"fixed32,6,opt,name=audio_processed,json=audioProcessed,proto3" json:"audio_processed,omitempty"` + // Message for pipeline states + PipelineStates *PipelineStates `protobuf:"bytes,7,opt,name=pipeline_states,json=pipelineStates,proto3,oneof" json:"pipeline_states,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *StreamingRecognitionResult) Reset() { + *x = StreamingRecognitionResult{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *StreamingRecognitionResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*StreamingRecognitionResult) ProtoMessage() {} + +func (x *StreamingRecognitionResult) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[15] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use StreamingRecognitionResult.ProtoReflect.Descriptor instead. +func (*StreamingRecognitionResult) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{15} +} + +func (x *StreamingRecognitionResult) GetAlternatives() []*SpeechRecognitionAlternative { + if x != nil { + return x.Alternatives + } + return nil +} + +func (x *StreamingRecognitionResult) GetIsFinal() bool { + if x != nil { + return x.IsFinal + } + return false +} + +func (x *StreamingRecognitionResult) GetStability() float32 { + if x != nil { + return x.Stability + } + return 0 +} + +func (x *StreamingRecognitionResult) GetChannelTag() int32 { + if x != nil { + return x.ChannelTag + } + return 0 +} + +func (x *StreamingRecognitionResult) GetAudioProcessed() float32 { + if x != nil { + return x.AudioProcessed + } + return 0 +} + +func (x *StreamingRecognitionResult) GetPipelineStates() *PipelineStates { + if x != nil { + return x.PipelineStates + } + return nil +} + +type RivaSpeechRecognitionConfigResponse_Config struct { + state protoimpl.MessageState `protogen:"open.v1"` + ModelName string `protobuf:"bytes,1,opt,name=model_name,json=modelName,proto3" json:"model_name,omitempty"` + Parameters map[string]string `protobuf:"bytes,2,rep,name=parameters,proto3" json:"parameters,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RivaSpeechRecognitionConfigResponse_Config) Reset() { + *x = RivaSpeechRecognitionConfigResponse_Config{} + mi := &file_riva_proto_riva_asr_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RivaSpeechRecognitionConfigResponse_Config) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RivaSpeechRecognitionConfigResponse_Config) ProtoMessage() {} + +func (x *RivaSpeechRecognitionConfigResponse_Config) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_asr_proto_msgTypes[16] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RivaSpeechRecognitionConfigResponse_Config.ProtoReflect.Descriptor instead. +func (*RivaSpeechRecognitionConfigResponse_Config) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_asr_proto_rawDescGZIP(), []int{1, 0} +} + +func (x *RivaSpeechRecognitionConfigResponse_Config) GetModelName() string { + if x != nil { + return x.ModelName + } + return "" +} + +func (x *RivaSpeechRecognitionConfigResponse_Config) GetParameters() map[string]string { + if x != nil { + return x.Parameters + } + return nil +} + +var File_riva_proto_riva_asr_proto protoreflect.FileDescriptor + +const file_riva_proto_riva_asr_proto_rawDesc = "" + + "\n" + + "\x19riva/proto/riva_asr.proto\x12\x0fnvidia.riva.asr\x1a\x1briva/proto/riva_audio.proto\x1a\x1criva/proto/riva_common.proto\"C\n" + + "\"RivaSpeechRecognitionConfigRequest\x12\x1d\n" + + "\n" + + "model_name\x18\x01 \x01(\tR\tmodelName\"\xdb\x02\n" + + "#RivaSpeechRecognitionConfigResponse\x12^\n" + + "\fmodel_config\x18\x01 \x03(\v2;.nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.ConfigR\vmodelConfig\x1a\xd3\x01\n" + + "\x06Config\x12\x1d\n" + + "\n" + + "model_name\x18\x01 \x01(\tR\tmodelName\x12k\n" + + "\n" + + "parameters\x18\x02 \x03(\v2K.nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.Config.ParametersEntryR\n" + + "parameters\x1a=\n" + + "\x0fParametersEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01\"\x8c\x01\n" + + "\x10RecognizeRequest\x12:\n" + + "\x06config\x18\x01 \x01(\v2\".nvidia.riva.asr.RecognitionConfigR\x06config\x12\x14\n" + + "\x05audio\x18\x02 \x01(\fR\x05audio\x12&\n" + + "\x02id\x18d \x01(\v2\x16.nvidia.riva.RequestIdR\x02id\"\xd9\x01\n" + + "\x19StreamingRecognizeRequest\x12X\n" + + "\x10streaming_config\x18\x01 \x01(\v2+.nvidia.riva.asr.StreamingRecognitionConfigH\x00R\x0fstreamingConfig\x12%\n" + + "\raudio_content\x18\x02 \x01(\fH\x00R\faudioContent\x12&\n" + + "\x02id\x18d \x01(\v2\x16.nvidia.riva.RequestIdR\x02idB\x13\n" + + "\x11streaming_request\"\x97\x03\n" + + "\x11EndpointingConfig\x12(\n" + + "\rstart_history\x18\x01 \x01(\x05H\x00R\fstartHistory\x88\x01\x01\x12,\n" + + "\x0fstart_threshold\x18\x02 \x01(\x02H\x01R\x0estartThreshold\x88\x01\x01\x12&\n" + + "\fstop_history\x18\x03 \x01(\x05H\x02R\vstopHistory\x88\x01\x01\x12*\n" + + "\x0estop_threshold\x18\x04 \x01(\x02H\x03R\rstopThreshold\x88\x01\x01\x12-\n" + + "\x10stop_history_eou\x18\x05 \x01(\x05H\x04R\x0estopHistoryEou\x88\x01\x01\x121\n" + + "\x12stop_threshold_eou\x18\x06 \x01(\x02H\x05R\x10stopThresholdEou\x88\x01\x01B\x10\n" + + "\x0e_start_historyB\x12\n" + + "\x10_start_thresholdB\x0f\n" + + "\r_stop_historyB\x11\n" + + "\x0f_stop_thresholdB\x13\n" + + "\x11_stop_history_eouB\x15\n" + + "\x13_stop_threshold_eou\"\x86\b\n" + + "\x11RecognitionConfig\x126\n" + + "\bencoding\x18\x01 \x01(\x0e2\x1a.nvidia.riva.AudioEncodingR\bencoding\x12*\n" + + "\x11sample_rate_hertz\x18\x02 \x01(\x05R\x0fsampleRateHertz\x12#\n" + + "\rlanguage_code\x18\x03 \x01(\tR\flanguageCode\x12)\n" + + "\x10max_alternatives\x18\x04 \x01(\x05R\x0fmaxAlternatives\x12)\n" + + "\x10profanity_filter\x18\x05 \x01(\bR\x0fprofanityFilter\x12G\n" + + "\x0fspeech_contexts\x18\x06 \x03(\v2\x1e.nvidia.riva.asr.SpeechContextR\x0espeechContexts\x12.\n" + + "\x13audio_channel_count\x18\a \x01(\x05R\x11audioChannelCount\x127\n" + + "\x18enable_word_time_offsets\x18\b \x01(\bR\x15enableWordTimeOffsets\x12@\n" + + "\x1cenable_automatic_punctuation\x18\v \x01(\bR\x1aenableAutomaticPunctuation\x12T\n" + + "'enable_separate_recognition_per_channel\x18\f \x01(\bR#enableSeparateRecognitionPerChannel\x12\x14\n" + + "\x05model\x18\r \x01(\tR\x05model\x121\n" + + "\x14verbatim_transcripts\x18\x0e \x01(\bR\x13verbatimTranscripts\x12X\n" + + "\x12diarization_config\x18\x13 \x01(\v2).nvidia.riva.asr.SpeakerDiarizationConfigR\x11diarizationConfig\x12n\n" + + "\x14custom_configuration\x18\x18 \x03(\v2;.nvidia.riva.asr.RecognitionConfig.CustomConfigurationEntryR\x13customConfiguration\x12V\n" + + "\x12endpointing_config\x18\x19 \x01(\v2\".nvidia.riva.asr.EndpointingConfigH\x00R\x11endpointingConfig\x88\x01\x01\x1aF\n" + + "\x18CustomConfigurationEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value:\x028\x01B\x15\n" + + "\x13_endpointing_config\"\x81\x01\n" + + "\x1aStreamingRecognitionConfig\x12:\n" + + "\x06config\x18\x01 \x01(\v2\".nvidia.riva.asr.RecognitionConfigR\x06config\x12'\n" + + "\x0finterim_results\x18\x02 \x01(\bR\x0einterimResults\"\x84\x01\n" + + "\x18SpeakerDiarizationConfig\x12<\n" + + "\x1aenable_speaker_diarization\x18\x01 \x01(\bR\x18enableSpeakerDiarization\x12*\n" + + "\x11max_speaker_count\x18\x02 \x01(\x05R\x0fmaxSpeakerCount\"?\n" + + "\rSpeechContext\x12\x18\n" + + "\aphrases\x18\x01 \x03(\tR\aphrases\x12\x14\n" + + "\x05boost\x18\x04 \x01(\x02R\x05boost\"\x7f\n" + + "\x11RecognizeResponse\x12B\n" + + "\aresults\x18\x01 \x03(\v2(.nvidia.riva.asr.SpeechRecognitionResultR\aresults\x12&\n" + + "\x02id\x18d \x01(\v2\x16.nvidia.riva.RequestIdR\x02id\"\xb6\x01\n" + + "\x17SpeechRecognitionResult\x12Q\n" + + "\falternatives\x18\x01 \x03(\v2-.nvidia.riva.asr.SpeechRecognitionAlternativeR\falternatives\x12\x1f\n" + + "\vchannel_tag\x18\x02 \x01(\x05R\n" + + "channelTag\x12'\n" + + "\x0faudio_processed\x18\x03 \x01(\x02R\x0eaudioProcessed\"\xb4\x01\n" + + "\x1cSpeechRecognitionAlternative\x12\x1e\n" + + "\n" + + "transcript\x18\x01 \x01(\tR\n" + + "transcript\x12\x1e\n" + + "\n" + + "confidence\x18\x02 \x01(\x02R\n" + + "confidence\x12/\n" + + "\x05words\x18\x03 \x03(\v2\x19.nvidia.riva.asr.WordInfoR\x05words\x12#\n" + + "\rlanguage_code\x18\x04 \x03(\tR\flanguageCode\"\xbe\x01\n" + + "\bWordInfo\x12\x1d\n" + + "\n" + + "start_time\x18\x01 \x01(\x05R\tstartTime\x12\x19\n" + + "\bend_time\x18\x02 \x01(\x05R\aendTime\x12\x12\n" + + "\x04word\x18\x03 \x01(\tR\x04word\x12\x1e\n" + + "\n" + + "confidence\x18\x04 \x01(\x02R\n" + + "confidence\x12\x1f\n" + + "\vspeaker_tag\x18\x05 \x01(\x05R\n" + + "speakerTag\x12#\n" + + "\rlanguage_code\x18\x06 \x01(\tR\flanguageCode\"\x8b\x01\n" + + "\x1aStreamingRecognizeResponse\x12E\n" + + "\aresults\x18\x01 \x03(\v2+.nvidia.riva.asr.StreamingRecognitionResultR\aresults\x12&\n" + + "\x02id\x18d \x01(\v2\x16.nvidia.riva.RequestIdR\x02id\"=\n" + + "\x0ePipelineStates\x12+\n" + + "\x11vad_probabilities\x18\x01 \x03(\x02R\x10vadProbabilities\"\xd5\x02\n" + + "\x1aStreamingRecognitionResult\x12Q\n" + + "\falternatives\x18\x01 \x03(\v2-.nvidia.riva.asr.SpeechRecognitionAlternativeR\falternatives\x12\x19\n" + + "\bis_final\x18\x02 \x01(\bR\aisFinal\x12\x1c\n" + + "\tstability\x18\x03 \x01(\x02R\tstability\x12\x1f\n" + + "\vchannel_tag\x18\x05 \x01(\x05R\n" + + "channelTag\x12'\n" + + "\x0faudio_processed\x18\x06 \x01(\x02R\x0eaudioProcessed\x12M\n" + + "\x0fpipeline_states\x18\a \x01(\v2\x1f.nvidia.riva.asr.PipelineStatesH\x00R\x0epipelineStates\x88\x01\x01B\x12\n" + + "\x10_pipeline_states2\xf2\x02\n" + + "\x15RivaSpeechRecognition\x12T\n" + + "\tRecognize\x12!.nvidia.riva.asr.RecognizeRequest\x1a\".nvidia.riva.asr.RecognizeResponse\"\x00\x12s\n" + + "\x12StreamingRecognize\x12*.nvidia.riva.asr.StreamingRecognizeRequest\x1a+.nvidia.riva.asr.StreamingRecognizeResponse\"\x00(\x010\x01\x12\x8d\x01\n" + + "\x1eGetRivaSpeechRecognitionConfig\x123.nvidia.riva.asr.RivaSpeechRecognitionConfigRequest\x1a4.nvidia.riva.asr.RivaSpeechRecognitionConfigResponse\"\x00B\xb6\x01\n" + + "\x13com.nvidia.riva.asrB\fRivaAsrProtoP\x01Z0github.com/rbright/sotto/proto/gen/go/riva/proto\xf8\x01\x01\xa2\x02\x03NRA\xaa\x02\x0fNvidia.Riva.Asr\xca\x02\x0fNvidia\\Riva\\Asr\xe2\x02\x1bNvidia\\Riva\\Asr\\GPBMetadata\xea\x02\x11Nvidia::Riva::Asrb\x06proto3" + +var ( + file_riva_proto_riva_asr_proto_rawDescOnce sync.Once + file_riva_proto_riva_asr_proto_rawDescData []byte +) + +func file_riva_proto_riva_asr_proto_rawDescGZIP() []byte { + file_riva_proto_riva_asr_proto_rawDescOnce.Do(func() { + file_riva_proto_riva_asr_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_riva_proto_riva_asr_proto_rawDesc), len(file_riva_proto_riva_asr_proto_rawDesc))) + }) + return file_riva_proto_riva_asr_proto_rawDescData +} + +var file_riva_proto_riva_asr_proto_msgTypes = make([]protoimpl.MessageInfo, 19) +var file_riva_proto_riva_asr_proto_goTypes = []any{ + (*RivaSpeechRecognitionConfigRequest)(nil), // 0: nvidia.riva.asr.RivaSpeechRecognitionConfigRequest + (*RivaSpeechRecognitionConfigResponse)(nil), // 1: nvidia.riva.asr.RivaSpeechRecognitionConfigResponse + (*RecognizeRequest)(nil), // 2: nvidia.riva.asr.RecognizeRequest + (*StreamingRecognizeRequest)(nil), // 3: nvidia.riva.asr.StreamingRecognizeRequest + (*EndpointingConfig)(nil), // 4: nvidia.riva.asr.EndpointingConfig + (*RecognitionConfig)(nil), // 5: nvidia.riva.asr.RecognitionConfig + (*StreamingRecognitionConfig)(nil), // 6: nvidia.riva.asr.StreamingRecognitionConfig + (*SpeakerDiarizationConfig)(nil), // 7: nvidia.riva.asr.SpeakerDiarizationConfig + (*SpeechContext)(nil), // 8: nvidia.riva.asr.SpeechContext + (*RecognizeResponse)(nil), // 9: nvidia.riva.asr.RecognizeResponse + (*SpeechRecognitionResult)(nil), // 10: nvidia.riva.asr.SpeechRecognitionResult + (*SpeechRecognitionAlternative)(nil), // 11: nvidia.riva.asr.SpeechRecognitionAlternative + (*WordInfo)(nil), // 12: nvidia.riva.asr.WordInfo + (*StreamingRecognizeResponse)(nil), // 13: nvidia.riva.asr.StreamingRecognizeResponse + (*PipelineStates)(nil), // 14: nvidia.riva.asr.PipelineStates + (*StreamingRecognitionResult)(nil), // 15: nvidia.riva.asr.StreamingRecognitionResult + (*RivaSpeechRecognitionConfigResponse_Config)(nil), // 16: nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.Config + nil, // 17: nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.Config.ParametersEntry + nil, // 18: nvidia.riva.asr.RecognitionConfig.CustomConfigurationEntry + (*RequestId)(nil), // 19: nvidia.riva.RequestId + (AudioEncoding)(0), // 20: nvidia.riva.AudioEncoding +} +var file_riva_proto_riva_asr_proto_depIdxs = []int32{ + 16, // 0: nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.model_config:type_name -> nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.Config + 5, // 1: nvidia.riva.asr.RecognizeRequest.config:type_name -> nvidia.riva.asr.RecognitionConfig + 19, // 2: nvidia.riva.asr.RecognizeRequest.id:type_name -> nvidia.riva.RequestId + 6, // 3: nvidia.riva.asr.StreamingRecognizeRequest.streaming_config:type_name -> nvidia.riva.asr.StreamingRecognitionConfig + 19, // 4: nvidia.riva.asr.StreamingRecognizeRequest.id:type_name -> nvidia.riva.RequestId + 20, // 5: nvidia.riva.asr.RecognitionConfig.encoding:type_name -> nvidia.riva.AudioEncoding + 8, // 6: nvidia.riva.asr.RecognitionConfig.speech_contexts:type_name -> nvidia.riva.asr.SpeechContext + 7, // 7: nvidia.riva.asr.RecognitionConfig.diarization_config:type_name -> nvidia.riva.asr.SpeakerDiarizationConfig + 18, // 8: nvidia.riva.asr.RecognitionConfig.custom_configuration:type_name -> nvidia.riva.asr.RecognitionConfig.CustomConfigurationEntry + 4, // 9: nvidia.riva.asr.RecognitionConfig.endpointing_config:type_name -> nvidia.riva.asr.EndpointingConfig + 5, // 10: nvidia.riva.asr.StreamingRecognitionConfig.config:type_name -> nvidia.riva.asr.RecognitionConfig + 10, // 11: nvidia.riva.asr.RecognizeResponse.results:type_name -> nvidia.riva.asr.SpeechRecognitionResult + 19, // 12: nvidia.riva.asr.RecognizeResponse.id:type_name -> nvidia.riva.RequestId + 11, // 13: nvidia.riva.asr.SpeechRecognitionResult.alternatives:type_name -> nvidia.riva.asr.SpeechRecognitionAlternative + 12, // 14: nvidia.riva.asr.SpeechRecognitionAlternative.words:type_name -> nvidia.riva.asr.WordInfo + 15, // 15: nvidia.riva.asr.StreamingRecognizeResponse.results:type_name -> nvidia.riva.asr.StreamingRecognitionResult + 19, // 16: nvidia.riva.asr.StreamingRecognizeResponse.id:type_name -> nvidia.riva.RequestId + 11, // 17: nvidia.riva.asr.StreamingRecognitionResult.alternatives:type_name -> nvidia.riva.asr.SpeechRecognitionAlternative + 14, // 18: nvidia.riva.asr.StreamingRecognitionResult.pipeline_states:type_name -> nvidia.riva.asr.PipelineStates + 17, // 19: nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.Config.parameters:type_name -> nvidia.riva.asr.RivaSpeechRecognitionConfigResponse.Config.ParametersEntry + 2, // 20: nvidia.riva.asr.RivaSpeechRecognition.Recognize:input_type -> nvidia.riva.asr.RecognizeRequest + 3, // 21: nvidia.riva.asr.RivaSpeechRecognition.StreamingRecognize:input_type -> nvidia.riva.asr.StreamingRecognizeRequest + 0, // 22: nvidia.riva.asr.RivaSpeechRecognition.GetRivaSpeechRecognitionConfig:input_type -> nvidia.riva.asr.RivaSpeechRecognitionConfigRequest + 9, // 23: nvidia.riva.asr.RivaSpeechRecognition.Recognize:output_type -> nvidia.riva.asr.RecognizeResponse + 13, // 24: nvidia.riva.asr.RivaSpeechRecognition.StreamingRecognize:output_type -> nvidia.riva.asr.StreamingRecognizeResponse + 1, // 25: nvidia.riva.asr.RivaSpeechRecognition.GetRivaSpeechRecognitionConfig:output_type -> nvidia.riva.asr.RivaSpeechRecognitionConfigResponse + 23, // [23:26] is the sub-list for method output_type + 20, // [20:23] is the sub-list for method input_type + 20, // [20:20] is the sub-list for extension type_name + 20, // [20:20] is the sub-list for extension extendee + 0, // [0:20] is the sub-list for field type_name +} + +func init() { file_riva_proto_riva_asr_proto_init() } +func file_riva_proto_riva_asr_proto_init() { + if File_riva_proto_riva_asr_proto != nil { + return + } + file_riva_proto_riva_audio_proto_init() + file_riva_proto_riva_common_proto_init() + file_riva_proto_riva_asr_proto_msgTypes[3].OneofWrappers = []any{ + (*StreamingRecognizeRequest_StreamingConfig)(nil), + (*StreamingRecognizeRequest_AudioContent)(nil), + } + file_riva_proto_riva_asr_proto_msgTypes[4].OneofWrappers = []any{} + file_riva_proto_riva_asr_proto_msgTypes[5].OneofWrappers = []any{} + file_riva_proto_riva_asr_proto_msgTypes[15].OneofWrappers = []any{} + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_riva_proto_riva_asr_proto_rawDesc), len(file_riva_proto_riva_asr_proto_rawDesc)), + NumEnums: 0, + NumMessages: 19, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_riva_proto_riva_asr_proto_goTypes, + DependencyIndexes: file_riva_proto_riva_asr_proto_depIdxs, + MessageInfos: file_riva_proto_riva_asr_proto_msgTypes, + }.Build() + File_riva_proto_riva_asr_proto = out.File + file_riva_proto_riva_asr_proto_goTypes = nil + file_riva_proto_riva_asr_proto_depIdxs = nil +} diff --git a/apps/sotto/proto/gen/go/riva/proto/riva_asr_grpc.pb.go b/apps/sotto/proto/gen/go/riva/proto/riva_asr_grpc.pb.go new file mode 100644 index 0000000..9bd2439 --- /dev/null +++ b/apps/sotto/proto/gen/go/riva/proto/riva_asr_grpc.pb.go @@ -0,0 +1,236 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: MIT + +// Copyright 2019 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.5.1 +// - protoc (unknown) +// source: riva/proto/riva_asr.proto + +package proto + +import ( + context "context" + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + RivaSpeechRecognition_Recognize_FullMethodName = "/nvidia.riva.asr.RivaSpeechRecognition/Recognize" + RivaSpeechRecognition_StreamingRecognize_FullMethodName = "/nvidia.riva.asr.RivaSpeechRecognition/StreamingRecognize" + RivaSpeechRecognition_GetRivaSpeechRecognitionConfig_FullMethodName = "/nvidia.riva.asr.RivaSpeechRecognition/GetRivaSpeechRecognitionConfig" +) + +// RivaSpeechRecognitionClient is the client API for RivaSpeechRecognition service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// The RivaSpeechRecognition service provides two mechanisms for converting +// speech to text. +type RivaSpeechRecognitionClient interface { + // Recognize expects a RecognizeRequest and returns a RecognizeResponse. This + // request will block until the audio is uploaded, processed, and a transcript + // is returned. + Recognize(ctx context.Context, in *RecognizeRequest, opts ...grpc.CallOption) (*RecognizeResponse, error) + // StreamingRecognize is a non-blocking API call that allows audio data to be + // fed to the server in chunks as it becomes available. Depending on the + // configuration in the StreamingRecognizeRequest, intermediate results can be + // sent back to the client. Recognition ends when the stream is closed by the + // client. + StreamingRecognize(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[StreamingRecognizeRequest, StreamingRecognizeResponse], error) + // Enables clients to request the configuration of the current ASR service, or + // a specific model within the service. + GetRivaSpeechRecognitionConfig(ctx context.Context, in *RivaSpeechRecognitionConfigRequest, opts ...grpc.CallOption) (*RivaSpeechRecognitionConfigResponse, error) +} + +type rivaSpeechRecognitionClient struct { + cc grpc.ClientConnInterface +} + +func NewRivaSpeechRecognitionClient(cc grpc.ClientConnInterface) RivaSpeechRecognitionClient { + return &rivaSpeechRecognitionClient{cc} +} + +func (c *rivaSpeechRecognitionClient) Recognize(ctx context.Context, in *RecognizeRequest, opts ...grpc.CallOption) (*RecognizeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RecognizeResponse) + err := c.cc.Invoke(ctx, RivaSpeechRecognition_Recognize_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *rivaSpeechRecognitionClient) StreamingRecognize(ctx context.Context, opts ...grpc.CallOption) (grpc.BidiStreamingClient[StreamingRecognizeRequest, StreamingRecognizeResponse], error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + stream, err := c.cc.NewStream(ctx, &RivaSpeechRecognition_ServiceDesc.Streams[0], RivaSpeechRecognition_StreamingRecognize_FullMethodName, cOpts...) + if err != nil { + return nil, err + } + x := &grpc.GenericClientStream[StreamingRecognizeRequest, StreamingRecognizeResponse]{ClientStream: stream} + return x, nil +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type RivaSpeechRecognition_StreamingRecognizeClient = grpc.BidiStreamingClient[StreamingRecognizeRequest, StreamingRecognizeResponse] + +func (c *rivaSpeechRecognitionClient) GetRivaSpeechRecognitionConfig(ctx context.Context, in *RivaSpeechRecognitionConfigRequest, opts ...grpc.CallOption) (*RivaSpeechRecognitionConfigResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RivaSpeechRecognitionConfigResponse) + err := c.cc.Invoke(ctx, RivaSpeechRecognition_GetRivaSpeechRecognitionConfig_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// RivaSpeechRecognitionServer is the server API for RivaSpeechRecognition service. +// All implementations must embed UnimplementedRivaSpeechRecognitionServer +// for forward compatibility. +// +// The RivaSpeechRecognition service provides two mechanisms for converting +// speech to text. +type RivaSpeechRecognitionServer interface { + // Recognize expects a RecognizeRequest and returns a RecognizeResponse. This + // request will block until the audio is uploaded, processed, and a transcript + // is returned. + Recognize(context.Context, *RecognizeRequest) (*RecognizeResponse, error) + // StreamingRecognize is a non-blocking API call that allows audio data to be + // fed to the server in chunks as it becomes available. Depending on the + // configuration in the StreamingRecognizeRequest, intermediate results can be + // sent back to the client. Recognition ends when the stream is closed by the + // client. + StreamingRecognize(grpc.BidiStreamingServer[StreamingRecognizeRequest, StreamingRecognizeResponse]) error + // Enables clients to request the configuration of the current ASR service, or + // a specific model within the service. + GetRivaSpeechRecognitionConfig(context.Context, *RivaSpeechRecognitionConfigRequest) (*RivaSpeechRecognitionConfigResponse, error) + mustEmbedUnimplementedRivaSpeechRecognitionServer() +} + +// UnimplementedRivaSpeechRecognitionServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedRivaSpeechRecognitionServer struct{} + +func (UnimplementedRivaSpeechRecognitionServer) Recognize(context.Context, *RecognizeRequest) (*RecognizeResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Recognize not implemented") +} +func (UnimplementedRivaSpeechRecognitionServer) StreamingRecognize(grpc.BidiStreamingServer[StreamingRecognizeRequest, StreamingRecognizeResponse]) error { + return status.Errorf(codes.Unimplemented, "method StreamingRecognize not implemented") +} +func (UnimplementedRivaSpeechRecognitionServer) GetRivaSpeechRecognitionConfig(context.Context, *RivaSpeechRecognitionConfigRequest) (*RivaSpeechRecognitionConfigResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method GetRivaSpeechRecognitionConfig not implemented") +} +func (UnimplementedRivaSpeechRecognitionServer) mustEmbedUnimplementedRivaSpeechRecognitionServer() {} +func (UnimplementedRivaSpeechRecognitionServer) testEmbeddedByValue() {} + +// UnsafeRivaSpeechRecognitionServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to RivaSpeechRecognitionServer will +// result in compilation errors. +type UnsafeRivaSpeechRecognitionServer interface { + mustEmbedUnimplementedRivaSpeechRecognitionServer() +} + +func RegisterRivaSpeechRecognitionServer(s grpc.ServiceRegistrar, srv RivaSpeechRecognitionServer) { + // If the following call pancis, it indicates UnimplementedRivaSpeechRecognitionServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&RivaSpeechRecognition_ServiceDesc, srv) +} + +func _RivaSpeechRecognition_Recognize_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RecognizeRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(RivaSpeechRecognitionServer).Recognize(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: RivaSpeechRecognition_Recognize_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(RivaSpeechRecognitionServer).Recognize(ctx, req.(*RecognizeRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _RivaSpeechRecognition_StreamingRecognize_Handler(srv interface{}, stream grpc.ServerStream) error { + return srv.(RivaSpeechRecognitionServer).StreamingRecognize(&grpc.GenericServerStream[StreamingRecognizeRequest, StreamingRecognizeResponse]{ServerStream: stream}) +} + +// This type alias is provided for backwards compatibility with existing code that references the prior non-generic stream type by name. +type RivaSpeechRecognition_StreamingRecognizeServer = grpc.BidiStreamingServer[StreamingRecognizeRequest, StreamingRecognizeResponse] + +func _RivaSpeechRecognition_GetRivaSpeechRecognitionConfig_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RivaSpeechRecognitionConfigRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(RivaSpeechRecognitionServer).GetRivaSpeechRecognitionConfig(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: RivaSpeechRecognition_GetRivaSpeechRecognitionConfig_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(RivaSpeechRecognitionServer).GetRivaSpeechRecognitionConfig(ctx, req.(*RivaSpeechRecognitionConfigRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// RivaSpeechRecognition_ServiceDesc is the grpc.ServiceDesc for RivaSpeechRecognition service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var RivaSpeechRecognition_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "nvidia.riva.asr.RivaSpeechRecognition", + HandlerType: (*RivaSpeechRecognitionServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "Recognize", + Handler: _RivaSpeechRecognition_Recognize_Handler, + }, + { + MethodName: "GetRivaSpeechRecognitionConfig", + Handler: _RivaSpeechRecognition_GetRivaSpeechRecognitionConfig_Handler, + }, + }, + Streams: []grpc.StreamDesc{ + { + StreamName: "StreamingRecognize", + Handler: _RivaSpeechRecognition_StreamingRecognize_Handler, + ServerStreams: true, + ClientStreams: true, + }, + }, + Metadata: "riva/proto/riva_asr.proto", +} diff --git a/apps/sotto/proto/gen/go/riva/proto/riva_audio.pb.go b/apps/sotto/proto/gen/go/riva/proto/riva_audio.pb.go new file mode 100644 index 0000000..78db6d6 --- /dev/null +++ b/apps/sotto/proto/gen/go/riva/proto/riva_audio.pb.go @@ -0,0 +1,158 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: MIT + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.6 +// protoc (unknown) +// source: riva/proto/riva_audio.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// AudioEncoding specifies the encoding of the audio bytes in the encapsulating +// message. +type AudioEncoding int32 + +const ( + // Not specified. + AudioEncoding_ENCODING_UNSPECIFIED AudioEncoding = 0 + // Uncompressed 16-bit signed little-endian samples (Linear PCM). + AudioEncoding_LINEAR_PCM AudioEncoding = 1 + // `FLAC` (Free Lossless Audio + // Codec) is the recommended encoding because it is + // lossless--therefore recognition is not compromised--and + // requires only about half the bandwidth of `LINEAR16`. `FLAC` stream + // encoding supports 16-bit and 24-bit samples, however, not all fields in + // `STREAMINFO` are supported. + AudioEncoding_FLAC AudioEncoding = 2 + // 8-bit samples that compand 14-bit audio samples using G.711 PCMU/mu-law. + AudioEncoding_MULAW AudioEncoding = 3 + AudioEncoding_OGGOPUS AudioEncoding = 4 + // 8-bit samples that compand 13-bit audio samples using G.711 PCMU/a-law. + AudioEncoding_ALAW AudioEncoding = 20 +) + +// Enum value maps for AudioEncoding. +var ( + AudioEncoding_name = map[int32]string{ + 0: "ENCODING_UNSPECIFIED", + 1: "LINEAR_PCM", + 2: "FLAC", + 3: "MULAW", + 4: "OGGOPUS", + 20: "ALAW", + } + AudioEncoding_value = map[string]int32{ + "ENCODING_UNSPECIFIED": 0, + "LINEAR_PCM": 1, + "FLAC": 2, + "MULAW": 3, + "OGGOPUS": 4, + "ALAW": 20, + } +) + +func (x AudioEncoding) Enum() *AudioEncoding { + p := new(AudioEncoding) + *p = x + return p +} + +func (x AudioEncoding) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (AudioEncoding) Descriptor() protoreflect.EnumDescriptor { + return file_riva_proto_riva_audio_proto_enumTypes[0].Descriptor() +} + +func (AudioEncoding) Type() protoreflect.EnumType { + return &file_riva_proto_riva_audio_proto_enumTypes[0] +} + +func (x AudioEncoding) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use AudioEncoding.Descriptor instead. +func (AudioEncoding) EnumDescriptor() ([]byte, []int) { + return file_riva_proto_riva_audio_proto_rawDescGZIP(), []int{0} +} + +var File_riva_proto_riva_audio_proto protoreflect.FileDescriptor + +const file_riva_proto_riva_audio_proto_rawDesc = "" + + "\n" + + "\x1briva/proto/riva_audio.proto\x12\vnvidia.riva*e\n" + + "\rAudioEncoding\x12\x18\n" + + "\x14ENCODING_UNSPECIFIED\x10\x00\x12\x0e\n" + + "\n" + + "LINEAR_PCM\x10\x01\x12\b\n" + + "\x04FLAC\x10\x02\x12\t\n" + + "\x05MULAW\x10\x03\x12\v\n" + + "\aOGGOPUS\x10\x04\x12\b\n" + + "\x04ALAW\x10\x14B\xa3\x01\n" + + "\x0fcom.nvidia.rivaB\x0eRivaAudioProtoP\x01Z0github.com/rbright/sotto/proto/gen/go/riva/proto\xf8\x01\x01\xa2\x02\x03NRX\xaa\x02\vNvidia.Riva\xca\x02\vNvidia\\Riva\xe2\x02\x17Nvidia\\Riva\\GPBMetadata\xea\x02\fNvidia::Rivab\x06proto3" + +var ( + file_riva_proto_riva_audio_proto_rawDescOnce sync.Once + file_riva_proto_riva_audio_proto_rawDescData []byte +) + +func file_riva_proto_riva_audio_proto_rawDescGZIP() []byte { + file_riva_proto_riva_audio_proto_rawDescOnce.Do(func() { + file_riva_proto_riva_audio_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_riva_proto_riva_audio_proto_rawDesc), len(file_riva_proto_riva_audio_proto_rawDesc))) + }) + return file_riva_proto_riva_audio_proto_rawDescData +} + +var file_riva_proto_riva_audio_proto_enumTypes = make([]protoimpl.EnumInfo, 1) +var file_riva_proto_riva_audio_proto_goTypes = []any{ + (AudioEncoding)(0), // 0: nvidia.riva.AudioEncoding +} +var file_riva_proto_riva_audio_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_riva_proto_riva_audio_proto_init() } +func file_riva_proto_riva_audio_proto_init() { + if File_riva_proto_riva_audio_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_riva_proto_riva_audio_proto_rawDesc), len(file_riva_proto_riva_audio_proto_rawDesc)), + NumEnums: 1, + NumMessages: 0, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_riva_proto_riva_audio_proto_goTypes, + DependencyIndexes: file_riva_proto_riva_audio_proto_depIdxs, + EnumInfos: file_riva_proto_riva_audio_proto_enumTypes, + }.Build() + File_riva_proto_riva_audio_proto = out.File + file_riva_proto_riva_audio_proto_goTypes = nil + file_riva_proto_riva_audio_proto_depIdxs = nil +} diff --git a/apps/sotto/proto/gen/go/riva/proto/riva_common.pb.go b/apps/sotto/proto/gen/go/riva/proto/riva_common.pb.go new file mode 100644 index 0000000..ec28c92 --- /dev/null +++ b/apps/sotto/proto/gen/go/riva/proto/riva_common.pb.go @@ -0,0 +1,127 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: MIT + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.6 +// protoc (unknown) +// source: riva/proto/riva_common.proto + +package proto + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// Specifies the request ID of the request. +type RequestId struct { + state protoimpl.MessageState `protogen:"open.v1"` + Value string `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RequestId) Reset() { + *x = RequestId{} + mi := &file_riva_proto_riva_common_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RequestId) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RequestId) ProtoMessage() {} + +func (x *RequestId) ProtoReflect() protoreflect.Message { + mi := &file_riva_proto_riva_common_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RequestId.ProtoReflect.Descriptor instead. +func (*RequestId) Descriptor() ([]byte, []int) { + return file_riva_proto_riva_common_proto_rawDescGZIP(), []int{0} +} + +func (x *RequestId) GetValue() string { + if x != nil { + return x.Value + } + return "" +} + +var File_riva_proto_riva_common_proto protoreflect.FileDescriptor + +const file_riva_proto_riva_common_proto_rawDesc = "" + + "\n" + + "\x1criva/proto/riva_common.proto\x12\vnvidia.riva\"!\n" + + "\tRequestId\x12\x14\n" + + "\x05value\x18\x01 \x01(\tR\x05valueB\xa4\x01\n" + + "\x0fcom.nvidia.rivaB\x0fRivaCommonProtoP\x01Z0github.com/rbright/sotto/proto/gen/go/riva/proto\xf8\x01\x01\xa2\x02\x03NRX\xaa\x02\vNvidia.Riva\xca\x02\vNvidia\\Riva\xe2\x02\x17Nvidia\\Riva\\GPBMetadata\xea\x02\fNvidia::Rivab\x06proto3" + +var ( + file_riva_proto_riva_common_proto_rawDescOnce sync.Once + file_riva_proto_riva_common_proto_rawDescData []byte +) + +func file_riva_proto_riva_common_proto_rawDescGZIP() []byte { + file_riva_proto_riva_common_proto_rawDescOnce.Do(func() { + file_riva_proto_riva_common_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_riva_proto_riva_common_proto_rawDesc), len(file_riva_proto_riva_common_proto_rawDesc))) + }) + return file_riva_proto_riva_common_proto_rawDescData +} + +var file_riva_proto_riva_common_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_riva_proto_riva_common_proto_goTypes = []any{ + (*RequestId)(nil), // 0: nvidia.riva.RequestId +} +var file_riva_proto_riva_common_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_riva_proto_riva_common_proto_init() } +func file_riva_proto_riva_common_proto_init() { + if File_riva_proto_riva_common_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_riva_proto_riva_common_proto_rawDesc), len(file_riva_proto_riva_common_proto_rawDesc)), + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_riva_proto_riva_common_proto_goTypes, + DependencyIndexes: file_riva_proto_riva_common_proto_depIdxs, + MessageInfos: file_riva_proto_riva_common_proto_msgTypes, + }.Build() + File_riva_proto_riva_common_proto = out.File + file_riva_proto_riva_common_proto_goTypes = nil + file_riva_proto_riva_common_proto_depIdxs = nil +} diff --git a/apps/sotto/proto/third_party/buf.yaml b/apps/sotto/proto/third_party/buf.yaml new file mode 100644 index 0000000..c126332 --- /dev/null +++ b/apps/sotto/proto/third_party/buf.yaml @@ -0,0 +1 @@ +version: v1 diff --git a/apps/sotto/proto/third_party/riva/proto/riva_asr.proto b/apps/sotto/proto/third_party/riva/proto/riva_asr.proto new file mode 100644 index 0000000..7e43712 --- /dev/null +++ b/apps/sotto/proto/third_party/riva/proto/riva_asr.proto @@ -0,0 +1,440 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: MIT + +// Copyright 2019 Google LLC. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +syntax = "proto3"; + +package nvidia.riva.asr; + +option cc_enable_arenas = true; +option go_package = "nvidia.com/riva_speech"; + +import "riva/proto/riva_audio.proto"; +import "riva/proto/riva_common.proto"; + +/* + * The RivaSpeechRecognition service provides two mechanisms for converting + * speech to text. + */ +service RivaSpeechRecognition { + // Recognize expects a RecognizeRequest and returns a RecognizeResponse. This + // request will block until the audio is uploaded, processed, and a transcript + // is returned. + rpc Recognize(RecognizeRequest) returns (RecognizeResponse) {} + // StreamingRecognize is a non-blocking API call that allows audio data to be + // fed to the server in chunks as it becomes available. Depending on the + // configuration in the StreamingRecognizeRequest, intermediate results can be + // sent back to the client. Recognition ends when the stream is closed by the + // client. + rpc StreamingRecognize(stream StreamingRecognizeRequest) + returns (stream StreamingRecognizeResponse) {} + + // Enables clients to request the configuration of the current ASR service, or + // a specific model within the service. + rpc GetRivaSpeechRecognitionConfig(RivaSpeechRecognitionConfigRequest) + returns (RivaSpeechRecognitionConfigResponse) {} +} + +/* + * RivaSpeechRecognitionConfigRequest + */ + +message RivaSpeechRecognitionConfigRequest { + // If model is specified only return config for model, otherwise return all + // configs. + string model_name = 1; +} + +message RivaSpeechRecognitionConfigResponse { + message Config { + string model_name = 1; + map parameters = 2; + } + + repeated Config model_config = 1; +} + +/* + * RecognizeRequest is used for batch processing of a single audio recording. + */ +message RecognizeRequest { + // Provides information to recognizer that specifies how to process the + // request. + RecognitionConfig config = 1; + // The raw audio data to be processed. The audio bytes must be encoded as + // specified in `RecognitionConfig`. + bytes audio = 2; + + // The ID to be associated with the request. If provided, this will be + // returned in the corresponding response. + RequestId id = 100; +} + +/* + * A StreamingRecognizeRequest is used to configure and stream audio content to + * the Riva ASR Service. The first message sent must include only a + * StreamingRecognitionConfig. Subsequent messages sent in the stream must + * contain only raw bytes of the audio to be recognized. + */ +message StreamingRecognizeRequest { + // The streaming request, which is either a streaming config or audio content. + oneof streaming_request { + // Provides information to the recognizer that specifies how to process the + // request. The first `StreamingRecognizeRequest` message must contain a + // `streaming_config` message. + StreamingRecognitionConfig streaming_config = 1; + // The audio data to be recognized. Sequential chunks of audio data are sent + // in sequential `StreamingRecognizeRequest` messages. The first + // `StreamingRecognizeRequest` message must not contain `audio` data + // and all subsequent `StreamingRecognizeRequest` messages must contain + // `audio` data. The audio bytes must be encoded as specified in + // `RecognitionConfig`. + bytes audio_content = 2; + } + + // The ID to be associated with the request. If provided, this will be + // returned in the corresponding responses. + RequestId id = 100; +} + +/* + * EndpointingConfig is used for configuring different fields related to start + * or end of utterance + */ +message EndpointingConfig { + // `start_history` is the size of the window, in milliseconds, used to + // detect start of utterance. + // `start_threshold` is the percentage threshold used to detect start of + // utterance. (0.0 to 1.0) + // If `start_threshold` of `start_history` ms of the acoustic model output + // have non-blank tokens, start of utterance is detected. + optional int32 start_history = 1; + optional float start_threshold = 2; + + // `stop_history` is the size of the window, in milliseconds, used to + // detect end of utterance. + // `stop_threshold` is the percentage threshold used to detect end of + // utterance. (0.0 to 1.0) + // If `stop_threshold` of `stop_history` ms of the acoustic model output have + // non-blank tokens, end of utterance is detected and decoder will be reset. + optional int32 stop_history = 3; + optional float stop_threshold = 4; + + // `stop_history_eou` and `stop_threshold_eou` are used for 2-pass end of utterance. + // `stop_history_eou` is the size of the window, in milliseconds, used to + // trigger 1st pass of end of utterance and generate a partial transcript + // with stability of 1. (stop_history_eou < stop_history) + // `stop_threshold_eou` is the percentage threshold used to trigger 1st + // pass of end of utterance. (0.0 to 1.0) + // If `stop_threshold_eou` of `stop_history_eou` ms of the acoustic model + // output have non-blank tokens, 1st pass of end of utterance is triggered. + optional int32 stop_history_eou = 5; + optional float stop_threshold_eou = 6; +} + +// Provides information to the recognizer that specifies how to process the +// request +message RecognitionConfig { + // The encoding of the audio data sent in the request. + // + // All encodings support only 1 channel (mono) audio. + AudioEncoding encoding = 1; + + // The sample rate in hertz (Hz) of the audio data sent in the + // `RecognizeRequest` or `StreamingRecognizeRequest` messages. + // The Riva server will automatically down-sample/up-sample the audio to + // match the ASR acoustic model sample rate. The sample rate value below 8kHz + // will not produce any meaningful output. + int32 sample_rate_hertz = 2; + + // Required. The language of the supplied audio as a + // [BCP-47](https://www.rfc-editor.org/rfc/bcp/bcp47.txt) language tag. + // Example: "en-US". + string language_code = 3; + + // Maximum number of recognition hypotheses to be returned. + // Specifically, the maximum number of `SpeechRecognizeAlternative` messages + // within each `SpeechRecognizeResult`. + // The server may return fewer than `max_alternatives`. + // If omitted, will return a maximum of one. + int32 max_alternatives = 4; + + // A custom field that enables profanity filtering for the generated + // transcripts. If set to 'true', the server filters out profanities, + // replacing all but the initial character in each filtered word with + // asterisks. For example, "x**". If set to `false` or omitted, profanities + // will not be filtered out. The default is `false`. + bool profanity_filter = 5; + + // Array of SpeechContext. + // A means to provide context to assist the speech recognition. For more + // information, see SpeechContext section + repeated SpeechContext speech_contexts = 6; + + // The number of channels in the input audio data. + // If `0` or omitted, defaults to one channel (mono). + // Note: Only single channel audio input is supported as of now. + int32 audio_channel_count = 7; + + // If `true`, the top result includes a list of words and the start and end + // time offsets (timestamps), and confidence scores for those words. If + // `false`, no word-level time offset information is returned. The default + // is `false`. + bool enable_word_time_offsets = 8; + + // If 'true', adds punctuation to recognition result hypotheses. The + // default 'false' value does not add punctuation to result hypotheses. + bool enable_automatic_punctuation = 11; + + // This needs to be set to `true` explicitly and `audio_channel_count` > 1 + // to get each channel recognized separately. The recognition result will + // contain a `channel_tag` field to state which channel that result belongs + // to. If this is not true, we will only recognize the first channel. The + // request is billed cumulatively for all channels recognized: + // `audio_channel_count` multiplied by the length of the audio. + // Note: This field is not yet supported. + bool enable_separate_recognition_per_channel = 12; + + // Which model to select for the given request. + // If empty, Riva will select the right model based on the other + // RecognitionConfig parameters. The model should correspond to the name + // passed to `riva-build` with the `--name` argument + string model = 13; + + // The verbatim_transcripts flag enables or disable inverse text + // normalization. 'true' returns exactly what was said, with no + // denormalization. 'false' applies inverse text normalization, also this is + // the default + bool verbatim_transcripts = 14; + + // Config to enable speaker diarization and set additional + // parameters. For non-streaming requests, the diarization results will be + // provided only in the top alternative of the FINAL SpeechRecognitionResult. + SpeakerDiarizationConfig diarization_config = 19; + + // Custom fields for passing request-level + // configuration options to plugins used in the + // model pipeline. + map custom_configuration = 24; + + // Config for tuning start or end of utterance parameters. + // If empty, Riva will use default values or custom values if specified in riva-build arguments. + optional EndpointingConfig endpointing_config = 25; +} + +// Provides information to the recognizer that specifies how to process the +// request +message StreamingRecognitionConfig { + // Provides information to the recognizer that specifies how to process the + // request + RecognitionConfig config = 1; + + // If `true`, interim results (tentative hypotheses) may be + // returned as they become available (these interim results are indicated with + // the `is_final=false` flag). + // If `false` or omitted, only `is_final=true` result(s) are returned. + bool interim_results = 2; +} + +// Config to enable speaker diarization. +message SpeakerDiarizationConfig { + // If 'true', enables speaker detection for each recognized word in + // the top alternative of the recognition result using a speaker_tag provided + // in the WordInfo. + bool enable_speaker_diarization = 1; + + // Maximum number of speakers in the conversation. This gives flexibility by + // allowing the system to automatically determine the correct number of + // speakers. If not set, the default value is 8. + int32 max_speaker_count = 2; +} + +// Provides "hints" to the speech recognizer to favor specific words and phrases +// in the results. +message SpeechContext { + // A list of strings containing words and phrases "hints" so that + // the speech recognition is more likely to recognize them. This can be used + // to improve the accuracy for specific words and phrases, for example, if + // specific commands are typically spoken by the user. This can also be used + // to add additional words to the vocabulary of the recognizer. + repeated string phrases = 1; + + // Hint Boost. Positive value will increase the probability that a specific + // phrase will be recognized over other similar sounding phrases. The higher + // the boost, the higher the chance of false positive recognition as well. + // Though `boost` can accept a wide range of positive values, most use cases + // are best served with values between 0 and 20. We recommend using a binary + // search approach to finding the optimal value for your use case. + float boost = 4; +} + +// The only message returned to the client by the `Recognize` method. It +// contains the result as zero or more sequential `SpeechRecognitionResult` +// messages. +message RecognizeResponse { + // Sequential list of transcription results corresponding to + // sequential portions of audio. Currently only returns one transcript. + repeated SpeechRecognitionResult results = 1; + + // The ID associated with the request + RequestId id = 100; +} + +// A speech recognition result corresponding to the latest transcript +message SpeechRecognitionResult { + // May contain one or more recognition hypotheses (up to the + // maximum specified in `max_alternatives`). + // These alternatives are ordered in terms of accuracy, with the top (first) + // alternative being the most probable, as ranked by the recognizer. + repeated SpeechRecognitionAlternative alternatives = 1; + + // For multi-channel audio, this is the channel number corresponding to the + // recognized result for the audio from that channel. + // For audio_channel_count = N, its output values can range from '1' to 'N'. + int32 channel_tag = 2; + + // Length of audio processed so far in seconds + float audio_processed = 3; +} + +// Alternative hypotheses (a.k.a. n-best list). +message SpeechRecognitionAlternative { + // Transcript text representing the words that the user spoke. + string transcript = 1; + + // The confidence estimate. A higher number indicates an estimated greater + // likelihood that the recognized word is correct. This field is set only for + // a non-streaming result or, for a streaming result where is_final=true. + // This field is not guaranteed to be accurate and users should not rely on + // it to be always provided. Although confidence can currently be roughly + // interpreted as a natural-log probability, the estimate computation varies + // with difference configurations, and is subject to change. The default of + // 0.0 is a sentinel value indicating confidence was not set. + float confidence = 2; + + // A list of word-specific information for each recognized word. Only + // populated if is_final=true + repeated WordInfo words = 3; + + // List of language codes detected in the transcript. + repeated string language_code = 4; +} + +// Word-specific information for recognized words. +message WordInfo { + // Time offset relative to the beginning of the audio in ms + // and corresponding to the start of the spoken word. + // This field is only set if `enable_word_time_offsets=true` and only + // in the top hypothesis. + int32 start_time = 1; + + // Time offset relative to the beginning of the audio in ms + // and corresponding to the end of the spoken word. + // This field is only set if `enable_word_time_offsets=true` and only + // in the top hypothesis. + int32 end_time = 2; + + // The word corresponding to this set of information. + string word = 3; + + // The confidence estimate. A higher number indicates an estimated greater + // likelihood that the recognized word is correct. This field is set only for + // a non-streaming result or, for a streaming result where is_final=true. + // This field is not guaranteed to be accurate and users should not rely on + // it to be always provided. Although confidence can currently be roughly + // interpreted as a natural-log probability, the estimate computation varies + // with difference configurations, and is subject to change. The default of + // 0.0 is a sentinel value indicating confidence was not set. + float confidence = 4; + + // Output only. A distinct integer value is assigned for every speaker within + // the audio. This field specifies which one of those speakers was detected to + // have spoken this word. Value ranges from '1' to diarization_speaker_count. + // speaker_tag is set if enable_speaker_diarization = 'true' and only in the + // top alternative. + int32 speaker_tag = 5; + + // The language code of the word. + string language_code = 6; +} + +// `StreamingRecognizeResponse` is the only message returned to the client by +// `StreamingRecognize`. A series of zero or more `StreamingRecognizeResponse` +// messages are streamed back to the client. +// +// Here are few examples of `StreamingRecognizeResponse`s +// +// 1. results { alternatives { transcript: "tube" } stability: 0.01 } +// +// 2. results { alternatives { transcript: "to be a" } stability: 0.01 } +// +// 3. results { alternatives { transcript: "to be or not to be" +// confidence: 0.92 } +// alternatives { transcript: "to bee or not to bee" } +// is_final: true } +// + +message StreamingRecognizeResponse { + // This repeated list contains the latest transcript(s) corresponding to + // audio currently being processed. + // Currently one result is returned, where each result can have multiple + // alternatives + repeated StreamingRecognitionResult results = 1; + + // The ID associated with the request + RequestId id = 100; +} + +message PipelineStates { + // Neural VAD probabilities + repeated float vad_probabilities = 1; +} + +// A streaming speech recognition result corresponding to a portion of the audio +// that is currently being processed. +message StreamingRecognitionResult { + // May contain one or more recognition hypotheses (up to the + // maximum specified in `max_alternatives`). + // These alternatives are ordered in terms of accuracy, with the top (first) + // alternative being the most probable, as ranked by the recognizer. + repeated SpeechRecognitionAlternative alternatives = 1; + + // If `false`, this `StreamingRecognitionResult` represents an + // interim result that may change. If `true`, this is the final time the + // speech service will return this particular `StreamingRecognitionResult`, + // the recognizer will not return any further hypotheses for this portion of + // the transcript and corresponding audio. + bool is_final = 2; + + // An estimate of the likelihood that the recognizer will not + // change its guess about this interim result. Values range from 0.0 + // (completely unstable) to 1.0 (completely stable). + // This field is only provided for interim results (`is_final=false`). + // The default of 0.0 is a sentinel value indicating `stability` was not set. + float stability = 3; + + // For multi-channel audio, this is the channel number corresponding to the + // recognized result for the audio from that channel. + // For audio_channel_count = N, its output values can range from '1' to 'N'. + int32 channel_tag = 5; + + // Length of audio processed so far in seconds + float audio_processed = 6; + + // Message for pipeline states + optional PipelineStates pipeline_states = 7; +} \ No newline at end of file diff --git a/apps/sotto/proto/third_party/riva/proto/riva_audio.proto b/apps/sotto/proto/third_party/riva/proto/riva_audio.proto new file mode 100644 index 0000000..f1a4b17 --- /dev/null +++ b/apps/sotto/proto/third_party/riva/proto/riva_audio.proto @@ -0,0 +1,37 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: MIT + +syntax = "proto3"; + +package nvidia.riva; + +option cc_enable_arenas = true; +option go_package = "nvidia.com/riva_speech"; + +/* + * AudioEncoding specifies the encoding of the audio bytes in the encapsulating + * message. + */ +enum AudioEncoding { + // Not specified. + ENCODING_UNSPECIFIED = 0; + + // Uncompressed 16-bit signed little-endian samples (Linear PCM). + LINEAR_PCM = 1; + + // `FLAC` (Free Lossless Audio + // Codec) is the recommended encoding because it is + // lossless--therefore recognition is not compromised--and + // requires only about half the bandwidth of `LINEAR16`. `FLAC` stream + // encoding supports 16-bit and 24-bit samples, however, not all fields in + // `STREAMINFO` are supported. + FLAC = 2; + + // 8-bit samples that compand 14-bit audio samples using G.711 PCMU/mu-law. + MULAW = 3; + + OGGOPUS = 4; + + // 8-bit samples that compand 13-bit audio samples using G.711 PCMU/a-law. + ALAW = 20; +} diff --git a/apps/sotto/proto/third_party/riva/proto/riva_common.proto b/apps/sotto/proto/third_party/riva/proto/riva_common.proto new file mode 100644 index 0000000..43ce2c8 --- /dev/null +++ b/apps/sotto/proto/third_party/riva/proto/riva_common.proto @@ -0,0 +1,16 @@ +// SPDX-FileCopyrightText: Copyright (c) 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: MIT + +syntax = "proto3"; + +package nvidia.riva; + +option cc_enable_arenas = true; +option go_package = "nvidia.com/riva_speech"; + +/* + * Specifies the request ID of the request. + */ +message RequestId { + string value = 1; +} diff --git a/buf.gen.yaml b/buf.gen.yaml new file mode 100644 index 0000000..e13ba5d --- /dev/null +++ b/buf.gen.yaml @@ -0,0 +1,16 @@ +version: v1 + +managed: + enabled: true + go_package_prefix: + default: github.com/rbright/sotto/proto/gen/go + +plugins: + - plugin: go + out: apps/sotto/proto/gen/go + opt: + - paths=source_relative + - plugin: go-grpc + out: apps/sotto/proto/gen/go + opt: + - paths=source_relative diff --git a/docs/architecture.md b/docs/architecture.md new file mode 100644 index 0000000..afadab4 --- /dev/null +++ b/docs/architecture.md @@ -0,0 +1,95 @@ +# Architecture + +`sotto` is a local-first ASR CLI with clear boundaries between state logic and side-effect adapters. + +## Component map + +```mermaid +flowchart LR + Trigger["Trigger\n(shell / hotkey / script)"] --> CLI["CLI + command dispatch"] + CLI --> IPC["IPC socket\n$XDG_RUNTIME_DIR/sotto.sock"] + IPC --> Session["Session controller\n(FSM + lifecycle)"] + + Session --> Audio["Audio capture\n(PipeWire/Pulse)"] + Session --> ASR["Riva streaming client\n(gRPC)"] + ASR --> Transcript["Transcript assembly\n(normalize + trailing space)"] + Transcript --> Output["Output adapters\n(clipboard + paste)"] + + Session --> Indicator["Indicator adapters\n(hypr or desktop) + cues"] + Session --> Logs["JSONL logs\n$XDG_STATE_HOME/sotto/log.jsonl"] +``` + +## Package responsibilities + +| Package | Responsibility | +| --- | --- | +| `internal/cli` | command/flag contract | +| `internal/app` | top-level wiring and dispatch | +| `internal/ipc` | single-instance socket lifecycle + forwarding | +| `internal/fsm` | legal session transitions | +| `internal/session` | lifecycle orchestration (`toggle`/`stop`/`cancel`) | +| `internal/audio` | device discovery/selection + capture stream | +| `internal/riva` | ASR stream transport + response accumulation | +| `internal/pipeline` | audio-to-ASR bridge + debug artifacts | +| `internal/transcript` | text normalization and assembly | +| `internal/output` | clipboard + paste adapters | +| `internal/indicator` | visual indicator + cue sound dispatch | +| `internal/doctor` | environment/readiness checks | +| `internal/logging` | session log bootstrap | + +## Runtime flow (`toggle` -> `toggle`) + +```mermaid +sequenceDiagram + participant T as Trigger + participant C as CLI + participant I as IPC + participant S as Session + participant A as Audio + participant R as Riva + participant O as Output + + T->>C: sotto toggle (start) + C->>I: acquire socket / become owner + I->>S: start + S->>A: start capture + S->>R: open stream + send config + A-->>R: PCM chunks + + T->>C: sotto toggle (stop) + C->>I: send stop + I->>S: stop + S->>R: close stream + gather transcript + S->>O: commit(transcript) +``` + +## Session state machine (`internal/fsm`) + +```mermaid +stateDiagram-v2 + [*] --> idle + + idle --> recording: start + recording --> transcribing: stop + recording --> idle: cancel + transcribing --> idle: transcribed + + idle --> error: fail + recording --> error: fail + transcribing --> error: fail + error --> idle: reset +``` + +Notes: + +- `fail` is a global event in code: it forces transition to `error` from any active state. +- Any transition not listed above is rejected by `fsm.Transition` as an invalid transition error. + +## Platform coupling (today) + +Current production path is Wayland + Hyprland: + +- default paste path calls `hyprctl sendshortcut` +- doctor checks require Hyprland session context + +This coupling is intentionally explicit and isolated in `internal/hypr` + output/doctor adapters so additional desktop targets can be added without changing session/FSM logic. diff --git a/docs/configuration.md b/docs/configuration.md new file mode 100644 index 0000000..0e0feac --- /dev/null +++ b/docs/configuration.md @@ -0,0 +1,201 @@ +# Configuration + +## Resolution order + +`sotto` loads configuration in this order: + +1. `--config ` +2. `$XDG_CONFIG_HOME/sotto/config.jsonc` +3. `~/.config/sotto/config.jsonc` + +If no `.jsonc` file exists at the default path, sotto falls back to legacy `config.conf` for compatibility. + +## Format + +Preferred format is **JSONC**: + +- JSON object root (`{ ... }`) +- line comments (`// ...`) +- block comments (`/* ... */`) +- trailing commas are accepted + +Unknown fields are hard errors. + +## Schema overview + +Top-level object keys: + +- `riva` +- `audio` +- `paste` +- `asr` +- `transcript` +- `indicator` +- `clipboard_cmd` +- `paste_cmd` +- `vocab` +- `debug` + +## Keys and defaults + +### `riva` + +| Key | Default | Notes | +| --- | --- | --- | +| `riva.grpc` | `127.0.0.1:50051` | gRPC ASR endpoint | +| `riva.http` | `127.0.0.1:9000` | HTTP endpoint for readiness checks | +| `riva.health_path` | `/v1/health/ready` | must start with `/` | + +### `audio` + +| Key | Default | Notes | +| --- | --- | --- | +| `audio.input` | `default` | preferred device match | +| `audio.fallback` | `default` | fallback device match | + +### `paste` + +| Key | Default | Notes | +| --- | --- | --- | +| `paste.enable` | `true` | run paste adapter after clipboard commit | +| `paste.shortcut` | `CTRL,V` | used by default Hyprland paste path when `paste_cmd` unset | + +### `asr` + +| Key | Default | Notes | +| --- | --- | --- | +| `asr.automatic_punctuation` | `true` | punctuation hint | +| `asr.language_code` | `en-US` | language code | +| `asr.model` | empty | optional explicit model | + +### `transcript` + +| Key | Default | Notes | +| --- | --- | --- | +| `transcript.trailing_space` | `true` | append space after assembled transcript | + +### `indicator` + +| Key | Default | Notes | +| --- | --- | --- | +| `indicator.enable` | `true` | visual indicator switch | +| `indicator.backend` | `hypr` | `hypr` or `desktop` | +| `indicator.desktop_app_name` | `sotto-indicator` | required for desktop backend | +| `indicator.sound_enable` | `true` | cue sounds switch | +| `indicator.sound_start_file` | empty | optional WAV path | +| `indicator.sound_stop_file` | empty | optional WAV path | +| `indicator.sound_complete_file` | empty | optional WAV path | +| `indicator.sound_cancel_file` | empty | optional WAV path | +| `indicator.height` | `28` | indicator size parameter | +| `indicator.text_recording` | `Recording…` | recording label | +| `indicator.text_processing` | `Transcribing…` | processing label | +| `indicator.text_transcribing` | alias | compatibility alias for `indicator.text_processing` | +| `indicator.text_error` | `Speech recognition error` | error label | +| `indicator.error_timeout_ms` | `1600` | `>= 0` | + +### command keys + +| Key | Default | Notes | +| --- | --- | --- | +| `clipboard_cmd` | `wl-copy --trim-newline` | command argv; no shell execution | +| `paste_cmd` | empty | optional explicit paste command override | + +### `vocab` + +| Key | Default | Notes | +| --- | --- | --- | +| `vocab.global` | empty | enabled vocab set names (array preferred; comma string also accepted) | +| `vocab.max_phrases` | `1024` | hard cap after dedupe | +| `vocab.sets` | empty map | map of named vocab sets | + +Each vocab set object supports: + +- `boost` (number) +- `phrases` (string array) + +### `debug` + +| Key | Default | Notes | +| --- | --- | --- | +| `debug.audio_dump` | `false` | write debug WAV artifacts | +| `debug.grpc_dump` | `false` | write raw ASR response JSON | + +## Desktop-notification placement example (mako) + +```conf +[app-name="sotto-indicator"] +anchor=top-center +default-timeout=0 +``` + +## Example (`config.jsonc`) + +```jsonc +{ + "riva": { + "grpc": "127.0.0.1:50051", + "http": "127.0.0.1:9000", + "health_path": "/v1/health/ready" + }, + + "audio": { + "input": "default", + "fallback": "default" + }, + + "paste": { + "enable": true, + "shortcut": "CTRL,V" + }, + + "clipboard_cmd": "wl-copy --trim-newline", + "paste_cmd": "", + + "asr": { + "automatic_punctuation": true, + "language_code": "en-US", + "model": "" + }, + + "transcript": { + "trailing_space": true + }, + + "indicator": { + "enable": true, + "backend": "hypr", + "desktop_app_name": "sotto-indicator", + "sound_enable": true, + "sound_start_file": "", + "sound_stop_file": "", + "sound_complete_file": "", + "sound_cancel_file": "", + "text_recording": "Recording…", + "text_processing": "Transcribing…", + "text_error": "Speech recognition error", + "error_timeout_ms": 1600 + }, + + "vocab": { + "global": ["internal"], + "max_phrases": 1024, + "sets": { + "internal": { + "boost": 14, + "phrases": ["Parakeet", "Riva", "local ASR"] + } + } + }, + + "debug": { + "audio_dump": false, + "grpc_dump": false + } +} +``` + +## Legacy format compatibility + +Legacy `key = value` config files are still accepted to avoid breaking deployed setups. + +When a legacy file is parsed, sotto emits a warning so you can migrate to JSONC. diff --git a/docs/modularity.md b/docs/modularity.md new file mode 100644 index 0000000..8775043 --- /dev/null +++ b/docs/modularity.md @@ -0,0 +1,46 @@ +# Modularity Review + +This document tracks readability risk and safe refactor slices. + +## Current state + +### What is working well + +- package boundaries are responsibility-oriented (`session`, `audio`, `riva`, `output`, `indicator`) +- state/FSM logic is separated from most I/O adapters +- test coverage is broad enough to support behavior-preserving extraction + +### Remaining readability hotspots + +Large handwritten files still carry higher review/refactor risk: + +- `internal/config/parser.go` +- `internal/audio/pulse.go` +- `internal/app/app.go` +- `internal/riva/client.go` +- `internal/pipeline/transcriber.go` +- `internal/session/session.go` + +Generated code is out of scope for these thresholds. + +## Refactor slices (behavior-preserving) + +1. `internal/config/parser.go` + - isolate token/scalar helpers + - isolate vocab block parser +2. `internal/audio/pulse.go` + - separate device selection from capture loop +3. `internal/app/app.go` + - split command dispatch, IPC forwarding, runtime bootstrap +4. `internal/riva/client.go` + - split stream lifecycle from segment merge helpers +5. `internal/pipeline/transcriber.go` + - split debug artifact writers from orchestration +6. `internal/session/session.go` + - isolate transition handling from commit/result assembly + +## Guardrails + +- soft target: handwritten files near `<= 250` LOC +- files above `~350` LOC need an explicit extraction note in `PLAN.md` +- extract in small slices with tests first diff --git a/docs/verification.md b/docs/verification.md new file mode 100644 index 0000000..4cd1270 --- /dev/null +++ b/docs/verification.md @@ -0,0 +1,50 @@ +# Verification + +## Required local gate + +Run before hand-off: + +```bash +just ci-check +nix build 'path:.#sotto' +``` + +## Optional integration-tag tests + +These are local-resource tests and are not part of the default CI gate: + +```bash +just test-integration +``` + +## Coverage snapshot + +```bash +go test ./apps/sotto/... -cover +``` + +## Manual runtime smoke (non-CI) + +Prerequisites: + +- local Riva endpoint is reachable +- active Wayland/Hyprland session +- valid `sotto` config + +Quick helpers: + +```bash +just smoke-riva-doctor +just smoke-riva-manual +``` + +Checklist: + +1. `sotto doctor` reports config/audio/Riva ready. +2. `sotto toggle` start -> speak -> `sotto toggle` stop. +3. Confirm non-empty transcript commit. +4. Confirm clipboard contains transcript after commit. +5. Confirm paste behavior for your configured adapter. +6. Run `sotto cancel` and verify clipboard is unchanged. +7. Stop Riva and confirm safe failure (no unintended clipboard/paste side effects). +8. Kill active `sotto` process mid-session and verify stale-socket recovery on next command. diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..3c3a5a8 --- /dev/null +++ b/flake.lock @@ -0,0 +1,61 @@ +{ + "nodes": { + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1771369470, + "narHash": "sha256-0NBlEBKkN3lufyvFegY4TYv5mCNHbi5OmBDrzihbBMQ=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "0182a361324364ae3f436a63005877674cf45efb", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..4d649c8 --- /dev/null +++ b/flake.nix @@ -0,0 +1,96 @@ +{ + description = "sotto: local-first ASR CLI"; + + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable"; + flake-utils.url = "github:numtide/flake-utils"; + }; + + outputs = + { + self, + nixpkgs, + flake-utils, + ... + }: + flake-utils.lib.eachDefaultSystem ( + system: + let + pkgs = import nixpkgs { inherit system; }; + version = if self ? shortRev then "0.1.0-${self.shortRev}" else "0.1.0-dev"; + in + { + packages = rec { + sotto = pkgs.buildGoModule { + pname = "sotto"; + inherit version; + src = ./.; + modRoot = "apps/sotto"; + go = pkgs.go_1_25; + subPackages = [ "cmd/sotto" ]; + vendorHash = "sha256-4/+DtLMcMwhckIH+ieVlsleXxzdA+J1kYXrpzVmW52s="; + env.GOWORK = "off"; + ldflags = [ + "-s" + "-w" + "-X github.com/rbright/sotto/internal/version.Version=${version}" + "-X github.com/rbright/sotto/internal/version.Commit=${self.shortRev or "dirty"}" + "-X github.com/rbright/sotto/internal/version.Date=unknown" + ]; + nativeBuildInputs = [ pkgs.makeWrapper ]; + postInstall = '' + wrapProgram $out/bin/sotto \ + --prefix PATH : ${ + pkgs.lib.makeBinPath [ + pkgs.curl + pkgs.hyprland + pkgs.pipewire + pkgs.systemd + pkgs.wl-clipboard + ] + } + ''; + }; + + default = sotto; + }; + + apps = { + sotto = flake-utils.lib.mkApp { drv = self.packages.${system}.sotto; }; + default = self.apps.${system}.sotto; + }; + + devShells.default = pkgs.mkShell { + packages = with pkgs; [ + go_1_25 + just + buf + protobuf + golangci-lint + statix + deadnix + nixfmt-rfc-style + prek + ]; + }; + + formatter = pkgs.nixfmt-rfc-style; + } + ) + // { + nixosModules.default = + { + config, + lib, + pkgs, + ... + }: + { + options.programs.sotto.enable = lib.mkEnableOption "Install the sotto CLI"; + + config = lib.mkIf config.programs.sotto.enable { + environment.systemPackages = [ self.packages.${pkgs.stdenv.hostPlatform.system}.sotto ]; + }; + }; + }; +} diff --git a/go.work b/go.work new file mode 100644 index 0000000..a09f467 --- /dev/null +++ b/go.work @@ -0,0 +1,3 @@ +go 1.25.5 + +use ./apps/sotto diff --git a/go.work.sum b/go.work.sum new file mode 100644 index 0000000..04cb365 --- /dev/null +++ b/go.work.sum @@ -0,0 +1,21 @@ +cel.dev/expr v0.25.1/go.mod h1:hrXvqGP6G6gyx8UAHSHJ5RGk//1Oj5nXQ2NI02Nrsg4= +cloud.google.com/go/compute/metadata v0.9.0/go.mod h1:E0bWwX5wTnLPedCKqk3pJmVgCBSM6qQI1yTBdEb3C10= +github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.30.0/go.mod h1:P4WPRUkOhJC13W//jWpyfJNDAIpvRbAUIYLX/4jtlE0= +github.com/cncf/xds/go v0.0.0-20251210132809-ee656c7534f5/go.mod h1:KdCmV+x/BuvyMxRnYBlmVaq4OLiKW6iRQfvC62cvdkI= +github.com/envoyproxy/go-control-plane v0.14.0/go.mod h1:NcS5X47pLl/hfqxU70yPwL9ZMkUlwlKxtAohpi2wBEU= +github.com/envoyproxy/go-control-plane/envoy v1.36.0/go.mod h1:ty89S1YCCVruQAm9OtKeEkQLTb+Lkz0k8v9W0Oxsv98= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= +github.com/envoyproxy/protoc-gen-validate v1.3.0/go.mod h1:HvYl7zwPa5mffgyeTUHA9zHIH36nmrm7oCbo4YKoSWA= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/golang/glog v1.2.5/go.mod h1:6AhwSGph0fcJtXVM/PEHPqZlFeoLxhs7/t5UDAwmO+w= +github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10/go.mod h1:t/avpk3KcrXxUnYOhZhMXJlSEyie6gQbtLq5NM3loB8= +github.com/spiffe/go-spiffe/v2 v2.6.0/go.mod h1:gm2SeUoMZEtpnzPNs2Csc0D/gX33k1xIx7lEzqblHEs= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +go.opentelemetry.io/contrib/detectors/gcp v1.39.0/go.mod h1:t/OGqzHBa5v6RHZwrDBJ2OirWc+4q/w2fTbLZwAKjTk= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= diff --git a/justfile b/justfile new file mode 100644 index 0000000..7931cfb --- /dev/null +++ b/justfile @@ -0,0 +1,10 @@ +set shell := ["bash", "-euo", "pipefail", "-c"] +set positional-arguments + +import ".just/common.just" +import ".just/go.just" +import ".just/codegen.just" +import ".just/ci.just" +import ".just/nix.just" +import ".just/hooks.just" +import ".just/smoke.just"