diff --git a/.github/workflows/build-amd.yml b/.github/workflows/build-amd.yml new file mode 100644 index 000000000..8ae5a5cf1 --- /dev/null +++ b/.github/workflows/build-amd.yml @@ -0,0 +1,69 @@ +# .github/workflows/build-amd.yml +name: Build AMD GPU binary (Vulkan + DirectML) + +# on: +# push: +# paths: +# - '.github/workflows/build-amd.yml' +# - 'CMakeLists.txt' +# - '**/*.cpp' +# - '**/*.hpp' +# - '**/*.cu' +# workflow_dispatch: + +jobs: + build-amd: + name: Windows AMD (Vulkan + DirectML) + runs-on: windows-2025 + + env: + VULKAN_VERSION: '1.3.261.1' + + steps: + # 1) Репозиторий + - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + + # 2) Vulkan SDK + - name: Install Vulkan SDK + shell: pwsh + run: | + curl.exe -o $Env:RUNNER_TEMP\VulkanSDK-Installer.exe ` + -L "https://sdk.lunarg.com/sdk/download/$Env:VULKAN_VERSION/windows/VulkanSDK-$Env:VULKAN_VERSION-Installer.exe" + & "$Env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install + Add-Content $Env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\$Env:VULKAN_VERSION" + Add-Content $Env:GITHUB_ENV "PATH=C:\VulkanSDK\$Env:VULKAN_VERSION\Bin;$Env:PATH" + + # 3) CMake (DirectML+Vulkan) + - name: Configure CMake + shell: pwsh + run: | + New-Item -ItemType Directory -Force -Path build + Set-Location build + cmake .. ` + -G "Visual Studio 17 2022" ` + -A x64 ` + -DCMAKE_BUILD_TYPE=Release ` + -DSD_VULKAN=ON ` + -DSD_DIRECTML=ON ` + -DBUILD_SHARED_LIBS=OFF + + # 4) Сборка sd и sd-server + - name: Build + shell: pwsh + run: | + cmake --build build --config Release --parallel + cmake --build build --config Release --target sd-server --parallel + + # 5) Артефакты + - name: Upload binaries + uses: actions/upload-artifact@v4 + with: + name: sd-win-amd + path: | + build/bin/Release/sd.exe + build/bin/Release/sd-server.exe + build/bin/Release/libstable-diffusion*.dll + build/bin/Release/*.spv # шейдеры Vulkan, если они генерируются diff --git a/.github/workflows/build-linux.yml b/.github/workflows/build-linux.yml new file mode 100644 index 000000000..a3e07d680 --- /dev/null +++ b/.github/workflows/build-linux.yml @@ -0,0 +1,161 @@ +name: sd-linux-build + +on: + workflow_dispatch: + push: + paths: + - '.github/workflows/build-linux.yml' + - 'CMakeLists.txt' + - '**/*.cpp' + - '**/*.hpp' + - '**/*.cu' + +jobs: + linux-vulkan: + name: Linux x86_64 (Vulkan) + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + + - name: Install deps (build + Vulkan + glslc) + run: | + sudo apt-get update + sudo apt-get install -y \ + build-essential cmake ninja-build \ + libvulkan-dev mesa-vulkan-drivers vulkan-tools \ + glslc + # опционально: чтобы меньше системных зависимостей + sudo apt-get install -y libc6-dev + + - name: Configure (Vulkan, mostly-static libstdc++) + run: | + cmake -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TRY_COMPILE_CONFIGURATION=Release \ + -DCMAKE_EXE_LINKER_FLAGS="-static-libstdc++ -static-libgcc" \ + -DBUILD_SHARED_LIBS=OFF \ + -DSD_VULKAN=ON + + - name: Build sd & sd-server + run: | + cmake --build build --parallel + cmake --build build --parallel --target sd-server + + - name: Collect package + run: | + mkdir -p package + # бинарники + cp build/bin/sd package/ + cp build/bin/sd-server package/ + # шейдеры, если сборка их сгенерировала + if ls build/bin/*.spv >/dev/null 2>&1; then + cp build/bin/*.spv package/ + fi + # если проект кладёт свои so — заберём + if ls build/bin/libstable-diffusion*.so >/dev/null 2>&1; then + cp build/bin/libstable-diffusion*.so package/ + fi + tar -czf sd-linux-vulkan.tar.gz -C package . + + - uses: actions/upload-artifact@v4 + with: + name: sd-linux-vulkan + path: sd-linux-vulkan.tar.gz + + linux-cuda: + name: Linux x86_64 (CUDA 12.4) + runs-on: ubuntu-latest + container: + image: nvidia/cuda:12.4.1-devel-ubuntu22.04 + + # В контейнере по умолчанию /bin/sh — фиксируем bash + defaults: + run: + shell: bash + + steps: + - name: Install build deps (inside CUDA container) + run: | + set -euo pipefail + apt-get update + DEBIAN_FRONTEND=noninteractive apt-get install -y \ + git ca-certificates build-essential cmake ninja-build \ + libc6-dev pkg-config patchelf rsync + update-ca-certificates + + - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + + - name: Show CUDA version + run: nvcc --version || true + + - name: Configure (CUDA, mostly-static libstdc++ + rpath=$ORIGIN) + env: + CMAKE_CUDA_ARCHS: "61;75;86;89" + run: | + set -euo pipefail + cmake -B build -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TRY_COMPILE_CONFIGURATION=Release \ + -DCMAKE_EXE_LINKER_FLAGS="-static-libstdc++ -static-libgcc -Wl,-rpath,'\$ORIGIN'" \ + -DBUILD_SHARED_LIBS=OFF \ + -DGGML_CUDA=ON \ + -DSD_CUDA=ON \ + "-DCMAKE_CUDA_ARCHITECTURES=${CMAKE_CUDA_ARCHS}" + + - name: Build sd & sd-server + run: | + set -euo pipefail + cmake --build build --parallel + cmake --build build --parallel --target sd-server + + - name: Collect package (bundle only needed CUDA SONAMEs) + run: | + set -euo pipefail + mkdir -p package + cp build/bin/sd package/ + cp build/bin/sd-server package/ + if ls build/bin/libstable-diffusion*.so >/dev/null 2>&1; then + cp -a build/bin/libstable-diffusion*.so package/ + fi + + # где лежат CUDA библиотеки в контейнере + CUDA_LIBDIR="/usr/local/cuda/lib64" + if [ ! -d "$CUDA_LIBDIR" ]; then + CUDA_LIBDIR="/usr/local/cuda/targets/x86_64-linux/lib" + fi + + # Копируем только SONAME (*.so.12*) + их target'ы (сохраняем symlink-и). + # Безверсные *.so НЕ нужны для рантайма. + rsync -a --info=NAME \ + "$CUDA_LIBDIR/libcudart.so.12"* \ + "$CUDA_LIBDIR/libcublas.so.12"* \ + "$CUDA_LIBDIR/libcublasLt.so.12"* \ + "$CUDA_LIBDIR/libcurand.so.10"* \ + "$CUDA_LIBDIR/libcusparse.so.12"* \ + "$CUDA_LIBDIR/libnvrtc.so.12"* \ + "$CUDA_LIBDIR/libnvJitLink.so.12"* \ + package/ || true + + # облегчить бинарники (не критично) + strip -s package/sd package/sd-server || true + + # гарантируем rpath -> $ORIGIN + patchelf --set-rpath '$ORIGIN' package/sd || true + patchelf --set-rpath '$ORIGIN' package/sd-server || true + + echo "ldd sd:" && (cd package && ldd ./sd || true) + echo "ldd sd-server:" && (cd package && ldd ./sd-server || true) + + tar -czf sd-linux-cuda.tar.gz -C package . + + - uses: actions/upload-artifact@v4 + with: + name: sd-linux-cuda + path: sd-linux-cuda.tar.gz diff --git a/.github/workflows/build-mac.yml b/.github/workflows/build-mac.yml new file mode 100644 index 000000000..e09971286 --- /dev/null +++ b/.github/workflows/build-mac.yml @@ -0,0 +1,40 @@ +name: sd-mac-build + +# on: [push, workflow_dispatch] + +jobs: + build: + runs-on: macos-14 + + steps: + # ① Клонируем репозиторий вместе с подмодулями + - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 1 + + # ② Конфигурация CMake (универсальный статический билд с Metal) + - name: Configure + run: | + cmake -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DSD_METAL=ON \ + -DBUILD_SHARED_LIBS=OFF \ + -DCMAKE_OSX_ARCHITECTURES="x86_64;arm64" + + # ③ Сборка sd и sd-server + - name: Build + run: | + cmake --build build --config Release -j$(sysctl -n hw.logicalcpu) + cmake --build build --config Release --target sd-server -j$(sysctl -n hw.logicalcpu) + + # ④ Артефакты (оба бинаря) + - uses: actions/upload-artifact@v4 + with: + name: sd-mac-universal + path: | + build/bin/sd + build/bin/sd-server + build/bin/libstable-diffusion*.dylib + build/bin/*.metal + build/bin/ggml-*.h diff --git a/.github/workflows/build-windows.yml b/.github/workflows/build-windows.yml new file mode 100644 index 000000000..9216338d2 --- /dev/null +++ b/.github/workflows/build-windows.yml @@ -0,0 +1,120 @@ +name: sd server + +# on: +# push: +# branches: [ master ] +# paths: +# - '.github/workflows/build-windows.yml' +# - 'CMakeLists.txt' +# - '**/*.cpp' +# - '**/*.cu' +# workflow_dispatch: + +jobs: + win-cuda-129: + runs-on: windows-2022 + + steps: + # 1 ─ репозиторий + - uses: actions/checkout@v4 + with: + submodules: recursive + fetch-depth: 0 + + # 2 ─ Ninja + - name: Install Ninja + shell: pwsh + run: choco install ninja -y + + # 3 ─ кэш CUDA + - name: Restore CUDA cache + id: cache-cuda + uses: actions/cache@v4 + with: + path: 'C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9' + key: cuda-12.9.1-win + + # 4 ─ установка CUDA 12.9 (если нет кэша) + - name: Install CUDA 12.9.1 + if: steps.cache-cuda.outputs.cache-hit != 'true' + shell: pwsh + run: | + $url = 'https://developer.download.nvidia.com/compute/cuda/12.9.1/local_installers/cuda_12.9.1_576.57_windows.exe' + $exe = "$env:RUNNER_TEMP\cuda129.exe" + Invoke-WebRequest -Uri $url -OutFile $exe + Start-Process $exe -ArgumentList '-s' -Wait + + # 5 ─ конфигурация + сборка (MSVC host, статический CRT) + - name: Configure & Build + shell: cmd + run: | + rem 1) build dir + mkdir build + + rem 2) MSVC + Windows SDK + call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvars64.bat" + + rem 3) Set CUDA environment + set "CUDA_PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9" + set "PATH=%CUDA_PATH%\bin;%CUDA_PATH%\libnvvp;%PATH%" + + rem 4) CMake → Ninja + cd build + cmake .. -G Ninja ^ + -DCMAKE_BUILD_TYPE=Release ^ + -DCMAKE_TRY_COMPILE_CONFIGURATION=Release ^ + -DCMAKE_MSVC_RUNTIME_LIBRARY=MultiThreaded ^ + -DGGML_CUDA=ON -DSD_CUDA=ON ^ + "-DCMAKE_CUDA_ARCHITECTURES=61;75;86;89;120" + + rem 5) сборка основного бинарника + cmake --build . --parallel + + rem 6) сборка примера-сервера sd-server.exe + cmake --build . --parallel --target sd-server + + # 6 ─ проверить что собралось + - name: List build outputs + shell: pwsh + run: | + Write-Host "Build directory contents:" + Get-ChildItem -Recurse build | Where-Object {$_.Extension -eq '.exe'} | ForEach-Object { Write-Host $_.FullName } + + # 7 ─ собрать пакет + - name: Collect runtime + shell: pwsh + run: | + $pkg = 'package' + New-Item $pkg -ItemType Directory -Force | Out-Null + + # Найти исполняемые файлы в build + $exeFiles = Get-ChildItem -Recurse build -Name "*.exe" + Write-Host "Found executables: $($exeFiles -join ', ')" + + foreach ($exe in $exeFiles) { + $srcPath = Join-Path "build" $exe + $fileName = Split-Path $exe -Leaf + Write-Host "Copying $srcPath -> $pkg\$fileName" + Copy-Item $srcPath "$pkg\$fileName" + } + + # Копировать CUDA DLL + $cudaPath = "C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.9\bin" + if (Test-Path $cudaPath) { + robocopy $cudaPath $pkg ` + cudart64_*.dll cublas64_*.dll cublasLt64_*.dll curand64_*.dll + } else { + Write-Warning "CUDA path not found: $cudaPath" + } + exit 0 + + # 8 ─ zip + artifact + - name: Create zip + shell: pwsh + run: 7z a sd-win-cuda12.9.zip package\* + + - name: Upload artifact + uses: actions/upload-artifact@v4 + with: + name: sd-win-cuda12.9 + path: sd-win-cuda12.9.zip diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index fe1410891..dd36ec9a0 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1,331 +1,331 @@ -name: CI - -on: - workflow_dispatch: # allows manual triggering - inputs: - create_release: - description: "Create new release" - required: true - type: boolean - push: - branches: - - master - - ci - paths: - [ - ".github/workflows/**", - "**/CMakeLists.txt", - "**/Makefile", - "**/*.h", - "**/*.hpp", - "**/*.c", - "**/*.cpp", - "**/*.cu", - ] - pull_request: - types: [opened, synchronize, reopened] - paths: - [ - "**/CMakeLists.txt", - "**/Makefile", - "**/*.h", - "**/*.hpp", - "**/*.c", - "**/*.cpp", - "**/*.cu", - ] - -env: - BRANCH_NAME: ${{ github.head_ref || github.ref_name }} - -jobs: - ubuntu-latest-cmake: - runs-on: ubuntu-latest - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v3 - with: - submodules: recursive - - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get install build-essential - - - name: Build - id: cmake_build - run: | - mkdir build - cd build - cmake .. -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON - cmake --build . --config Release - - - name: Get commit hash - id: commit - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/main' ) || github.event.inputs.create_release == 'true' }} - uses: pr-mpt/actions-commit-hash@v2 - - - name: Fetch system info - id: system-info - run: | - echo "CPU_ARCH=`uname -m`" >> "$GITHUB_OUTPUT" - echo "OS_NAME=`lsb_release -s -i`" >> "$GITHUB_OUTPUT" - echo "OS_VERSION=`lsb_release -s -r`" >> "$GITHUB_OUTPUT" - echo "OS_TYPE=`uname -s`" >> "$GITHUB_OUTPUT" - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - cp ggml/LICENSE ./build/bin/ggml.txt - cp LICENSE ./build/bin/stable-diffusion.cpp.txt - zip -j sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip ./build/bin/* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip - path: | - sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip - - macOS-latest-cmake: - runs-on: macos-latest - - steps: - - name: Clone - id: checkout - uses: actions/checkout@v3 - with: - submodules: recursive - - - name: Dependencies - id: depends - run: | - brew install zip - - - name: Build - id: cmake_build - run: | - sysctl -a - mkdir build - cd build - cmake .. -DGGML_AVX2=ON -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" -DSD_BUILD_SHARED_LIBS=ON - cmake --build . --config Release - - - name: Get commit hash - id: commit - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/main' ) || github.event.inputs.create_release == 'true' }} - uses: pr-mpt/actions-commit-hash@v2 - - - name: Fetch system info - id: system-info - run: | - echo "CPU_ARCH=`uname -m`" >> "$GITHUB_OUTPUT" - echo "OS_NAME=`sw_vers -productName`" >> "$GITHUB_OUTPUT" - echo "OS_VERSION=`sw_vers -productVersion`" >> "$GITHUB_OUTPUT" - echo "OS_TYPE=`uname -s`" >> "$GITHUB_OUTPUT" - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - cp ggml/LICENSE ./build/bin/ggml.txt - cp LICENSE ./build/bin/stable-diffusion.cpp.txt - zip -j sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip ./build/bin/* - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip - path: | - sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip - - windows-latest-cmake: - runs-on: windows-2019 - - env: - VULKAN_VERSION: 1.3.261.1 - - strategy: - matrix: - include: - - build: "noavx" - defines: "-DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DSD_BUILD_SHARED_LIBS=ON" - - build: "avx2" - defines: "-DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON" - - build: "avx" - defines: "-DGGML_AVX2=OFF -DSD_BUILD_SHARED_LIBS=ON" - - build: "avx512" - defines: "-DGGML_AVX512=ON -DSD_BUILD_SHARED_LIBS=ON" - - build: "cuda12" - defines: "-DSD_CUBLAS=ON -DSD_BUILD_SHARED_LIBS=ON" - - build: "rocm5.5" - defines: '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030" -DSD_BUILD_SHARED_LIBS=ON' - - build: 'vulkan' - defines: "-DSD_VULKAN=ON -DSD_BUILD_SHARED_LIBS=ON" - steps: - - name: Clone - id: checkout - uses: actions/checkout@v3 - with: - submodules: recursive - - - name: Install cuda-toolkit - id: cuda-toolkit - if: ${{ matrix.build == 'cuda12' }} - uses: Jimver/cuda-toolkit@v0.2.11 - with: - cuda: "12.2.0" - method: "network" - sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]' - - - name: Install rocm-toolkit - id: rocm-toolkit - if: ${{ matrix.build == 'rocm5.5' }} - uses: Cyberhan123/rocm-toolkit@v0.1.0 - with: - rocm: "5.5.0" - - - name: Install Ninja - id: install-ninja - if: ${{ matrix.build == 'rocm5.5' }} - uses: urkle/action-get-ninja@v1 - with: - version: 1.11.1 - - name: Install Vulkan SDK - id: get_vulkan - if: ${{ matrix.build == 'vulkan' }} - run: | - curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" - & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install - Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" - Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" - - - name: Build - id: cmake_build - run: | - mkdir build - cd build - cmake .. ${{ matrix.defines }} - cmake --build . --config Release - - - name: Check AVX512F support - id: check_avx512f - if: ${{ matrix.build == 'avx512' }} - continue-on-error: true - run: | - cd build - $vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath) - $msvc = $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim())) - $cl = $(join-path $msvc 'bin\Hostx64\x64\cl.exe') - echo 'int main(void){unsigned int a[4];__cpuid(a,7);return !(a[1]&65536);}' >> avx512f.c - & $cl /O2 /GS- /kernel avx512f.c /link /nodefaultlib /entry:main - .\avx512f.exe && echo "AVX512F: YES" && ( echo HAS_AVX512F=1 >> $env:GITHUB_ENV ) || echo "AVX512F: NO" - - - name: Get commit hash - id: commit - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: pr-mpt/actions-commit-hash@v2 - - - name: Pack artifacts - id: pack_artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - run: | - $filePath = ".\build\bin\Release\*" - if (Test-Path $filePath) { - echo "Exists at path $filePath" - Copy-Item ggml/LICENSE .\build\bin\Release\ggml.txt - Copy-Item LICENSE .\build\bin\Release\stable-diffusion.cpp.txt - } elseif (Test-Path ".\build\bin\stable-diffusion.dll") { - $filePath = ".\build\bin\*" - echo "Exists at path $filePath" - Copy-Item ggml/LICENSE .\build\bin\ggml.txt - Copy-Item LICENSE .\build\bin\stable-diffusion.cpp.txt - } else { - ls .\build\bin - throw "Can't find stable-diffusion.dll" - } - 7z a sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip $filePath - - - name: Copy and pack Cuda runtime - id: pack_cuda_runtime - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' && matrix.build == 'cuda12' ) || github.event.inputs.create_release == 'true' }} - run: | - echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" - $dst='.\build\bin\cudart\' - robocopy "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll - 7z a cudart-sd-bin-win-cu12-x64.zip $dst\* - - - name: Upload Cuda runtime - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' && matrix.build == 'cuda12' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - name: sd-cudart-sd-bin-win-cu12-x64.zip - path: | - cudart-sd-bin-win-cu12-x64.zip - - - name: Upload artifacts - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - uses: actions/upload-artifact@v4 - with: - name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip - path: | - sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip - - release: - if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} - - runs-on: ubuntu-latest - - needs: - - ubuntu-latest-cmake - - macOS-latest-cmake - - windows-latest-cmake - - steps: - - name: Download artifacts - id: download-artifact - uses: actions/download-artifact@v4 - with: - path: ./artifact - pattern: sd-* - merge-multiple: true - - - name: Get commit hash - id: commit - uses: pr-mpt/actions-commit-hash@v2 - - - name: Create release - id: create_release - uses: anzz1/action-create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - with: - tag_name: ${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }} - - - name: Upload release - id: upload_release - uses: actions/github-script@v3 - with: - github-token: ${{secrets.GITHUB_TOKEN}} - script: | - const path = require('path'); - const fs = require('fs'); - const release_id = '${{ steps.create_release.outputs.id }}'; - for (let file of await fs.readdirSync('./artifact')) { - if (path.extname(file) === '.zip') { - console.log('uploadReleaseAsset', file); - await github.repos.uploadReleaseAsset({ - owner: context.repo.owner, - repo: context.repo.repo, - release_id: release_id, - name: file, - data: await fs.readFileSync(`./artifact/${file}`) - }); - } - } +# name: CI + +# on: +# workflow_dispatch: # allows manual triggering +# inputs: +# create_release: +# description: "Create new release" +# required: true +# type: boolean +# push: +# branches: +# - master +# - ci +# paths: +# [ +# ".github/workflows/**", +# "**/CMakeLists.txt", +# "**/Makefile", +# "**/*.h", +# "**/*.hpp", +# "**/*.c", +# "**/*.cpp", +# "**/*.cu", +# ] +# pull_request: +# types: [opened, synchronize, reopened] +# paths: +# [ +# "**/CMakeLists.txt", +# "**/Makefile", +# "**/*.h", +# "**/*.hpp", +# "**/*.c", +# "**/*.cpp", +# "**/*.cu", +# ] + +# env: +# BRANCH_NAME: ${{ github.head_ref || github.ref_name }} + +# jobs: +# ubuntu-latest-cmake: +# runs-on: ubuntu-latest + +# steps: +# - name: Clone +# id: checkout +# uses: actions/checkout@v3 +# with: +# submodules: recursive + +# - name: Dependencies +# id: depends +# run: | +# sudo apt-get update +# sudo apt-get install build-essential + +# - name: Build +# id: cmake_build +# run: | +# mkdir build +# cd build +# cmake .. -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON +# cmake --build . --config Release + +# - name: Get commit hash +# id: commit +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/main' ) || github.event.inputs.create_release == 'true' }} +# uses: pr-mpt/actions-commit-hash@v2 + +# - name: Fetch system info +# id: system-info +# run: | +# echo "CPU_ARCH=`uname -m`" >> "$GITHUB_OUTPUT" +# echo "OS_NAME=`lsb_release -s -i`" >> "$GITHUB_OUTPUT" +# echo "OS_VERSION=`lsb_release -s -r`" >> "$GITHUB_OUTPUT" +# echo "OS_TYPE=`uname -s`" >> "$GITHUB_OUTPUT" + +# - name: Pack artifacts +# id: pack_artifacts +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} +# run: | +# cp ggml/LICENSE ./build/bin/ggml.txt +# cp LICENSE ./build/bin/stable-diffusion.cpp.txt +# zip -j sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip ./build/bin/* + +# - name: Upload artifacts +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} +# uses: actions/upload-artifact@v4 +# with: +# name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip +# path: | +# sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip + +# macOS-latest-cmake: +# runs-on: macos-latest + +# steps: +# - name: Clone +# id: checkout +# uses: actions/checkout@v3 +# with: +# submodules: recursive + +# - name: Dependencies +# id: depends +# run: | +# brew install zip + +# - name: Build +# id: cmake_build +# run: | +# sysctl -a +# mkdir build +# cd build +# cmake .. -DGGML_AVX2=ON -DCMAKE_OSX_ARCHITECTURES="arm64;x86_64" -DSD_BUILD_SHARED_LIBS=ON +# cmake --build . --config Release + +# - name: Get commit hash +# id: commit +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/main' ) || github.event.inputs.create_release == 'true' }} +# uses: pr-mpt/actions-commit-hash@v2 + +# - name: Fetch system info +# id: system-info +# run: | +# echo "CPU_ARCH=`uname -m`" >> "$GITHUB_OUTPUT" +# echo "OS_NAME=`sw_vers -productName`" >> "$GITHUB_OUTPUT" +# echo "OS_VERSION=`sw_vers -productVersion`" >> "$GITHUB_OUTPUT" +# echo "OS_TYPE=`uname -s`" >> "$GITHUB_OUTPUT" + +# - name: Pack artifacts +# id: pack_artifacts +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} +# run: | +# cp ggml/LICENSE ./build/bin/ggml.txt +# cp LICENSE ./build/bin/stable-diffusion.cpp.txt +# zip -j sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip ./build/bin/* + +# - name: Upload artifacts +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} +# uses: actions/upload-artifact@v4 +# with: +# name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip +# path: | +# sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-${{ steps.system-info.outputs.OS_TYPE }}-${{ steps.system-info.outputs.OS_NAME }}-${{ steps.system-info.outputs.OS_VERSION }}-${{ steps.system-info.outputs.CPU_ARCH }}.zip + +# windows-latest-cmake: +# runs-on: windows-2025 + +# env: +# VULKAN_VERSION: 1.3.261.1 + +# strategy: +# matrix: +# include: +# - build: "noavx" +# defines: "-DGGML_NATIVE=OFF -DGGML_AVX=OFF -DGGML_AVX2=OFF -DGGML_FMA=OFF -DSD_BUILD_SHARED_LIBS=ON" +# - build: "avx2" +# defines: "-DGGML_NATIVE=OFF -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON" +# - build: "avx" +# defines: "-DGGML_NATIVE=OFF -DGGML_AVX=ON -DGGML_AVX2=OFF -DSD_BUILD_SHARED_LIBS=ON" +# - build: "avx512" +# defines: "-DGGML_NATIVE=OFF -DGGML_AVX512=ON -DGGML_AVX=ON -DGGML_AVX2=ON -DSD_BUILD_SHARED_LIBS=ON" +# - build: "cuda12" +# defines: "-DSD_CUDA=ON -DSD_BUILD_SHARED_LIBS=ON -DCMAKE_CUDA_ARCHITECTURES=90;89;80;75" +# # - build: "rocm5.5" +# # defines: '-G Ninja -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_HIPBLAS=ON -DCMAKE_BUILD_TYPE=Release -DAMDGPU_TARGETS="gfx1100;gfx1102;gfx1030" -DSD_BUILD_SHARED_LIBS=ON' +# - build: 'vulkan' +# defines: "-DSD_VULKAN=ON -DSD_BUILD_SHARED_LIBS=ON" +# steps: +# - name: Clone +# id: checkout +# uses: actions/checkout@v3 +# with: +# submodules: recursive + +# - name: Install cuda-toolkit +# id: cuda-toolkit +# if: ${{ matrix.build == 'cuda12' }} +# uses: Jimver/cuda-toolkit@v0.2.19 +# with: +# cuda: "12.6.2" +# method: "network" +# sub-packages: '["nvcc", "cudart", "cublas", "cublas_dev", "thrust", "visual_studio_integration"]' + +# - name: Install rocm-toolkit +# id: rocm-toolkit +# if: ${{ matrix.build == 'rocm5.5' }} +# uses: Cyberhan123/rocm-toolkit@v0.1.0 +# with: +# rocm: "5.5.0" + +# - name: Install Ninja +# id: install-ninja +# if: ${{ matrix.build == 'rocm5.5' }} +# uses: urkle/action-get-ninja@v1 +# with: +# version: 1.11.1 +# - name: Install Vulkan SDK +# id: get_vulkan +# if: ${{ matrix.build == 'vulkan' }} +# run: | +# curl.exe -o $env:RUNNER_TEMP/VulkanSDK-Installer.exe -L "https://sdk.lunarg.com/sdk/download/${env:VULKAN_VERSION}/windows/VulkanSDK-${env:VULKAN_VERSION}-Installer.exe" +# & "$env:RUNNER_TEMP\VulkanSDK-Installer.exe" --accept-licenses --default-answer --confirm-command install +# Add-Content $env:GITHUB_ENV "VULKAN_SDK=C:\VulkanSDK\${env:VULKAN_VERSION}" +# Add-Content $env:GITHUB_PATH "C:\VulkanSDK\${env:VULKAN_VERSION}\bin" + +# - name: Build +# id: cmake_build +# run: | +# mkdir build +# cd build +# cmake .. ${{ matrix.defines }} +# cmake --build . --config Release + +# - name: Check AVX512F support +# id: check_avx512f +# if: ${{ matrix.build == 'avx512' }} +# continue-on-error: true +# run: | +# cd build +# $vcdir = $(vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath) +# $msvc = $(join-path $vcdir $('VC\Tools\MSVC\'+$(gc -raw $(join-path $vcdir 'VC\Auxiliary\Build\Microsoft.VCToolsVersion.default.txt')).Trim())) +# $cl = $(join-path $msvc 'bin\Hostx64\x64\cl.exe') +# echo 'int main(void){unsigned int a[4];__cpuid(a,7);return !(a[1]&65536);}' >> avx512f.c +# & $cl /O2 /GS- /kernel avx512f.c /link /nodefaultlib /entry:main +# .\avx512f.exe && echo "AVX512F: YES" && ( echo HAS_AVX512F=1 >> $env:GITHUB_ENV ) || echo "AVX512F: NO" + +# - name: Get commit hash +# id: commit +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} +# uses: pr-mpt/actions-commit-hash@v2 + +# - name: Pack artifacts +# id: pack_artifacts +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} +# run: | +# $filePath = ".\build\bin\Release\*" +# if (Test-Path $filePath) { +# echo "Exists at path $filePath" +# Copy-Item ggml/LICENSE .\build\bin\Release\ggml.txt +# Copy-Item LICENSE .\build\bin\Release\stable-diffusion.cpp.txt +# } elseif (Test-Path ".\build\bin\stable-diffusion.dll") { +# $filePath = ".\build\bin\*" +# echo "Exists at path $filePath" +# Copy-Item ggml/LICENSE .\build\bin\ggml.txt +# Copy-Item LICENSE .\build\bin\stable-diffusion.cpp.txt +# } else { +# ls .\build\bin +# throw "Can't find stable-diffusion.dll" +# } +# 7z a sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip $filePath + +# - name: Copy and pack Cuda runtime +# id: pack_cuda_runtime +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' && matrix.build == 'cuda12' ) || github.event.inputs.create_release == 'true' }} +# run: | +# echo "Cuda install location: ${{steps.cuda-toolkit.outputs.CUDA_PATH}}" +# $dst='.\build\bin\cudart\' +# robocopy "${{steps.cuda-toolkit.outputs.CUDA_PATH}}\bin" $dst cudart64_*.dll cublas64_*.dll cublasLt64_*.dll +# 7z a cudart-sd-bin-win-cu12-x64.zip $dst\* + +# - name: Upload Cuda runtime +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' && matrix.build == 'cuda12' ) || github.event.inputs.create_release == 'true' }} +# uses: actions/upload-artifact@v4 +# with: +# name: sd-cudart-sd-bin-win-cu12-x64.zip +# path: | +# cudart-sd-bin-win-cu12-x64.zip + +# - name: Upload artifacts +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} +# uses: actions/upload-artifact@v4 +# with: +# name: sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip +# path: | +# sd-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip + +# release: +# if: ${{ ( github.event_name == 'push' && github.ref == 'refs/heads/master' ) || github.event.inputs.create_release == 'true' }} + +# runs-on: ubuntu-latest + +# needs: +# - ubuntu-latest-cmake +# - macOS-latest-cmake +# - windows-latest-cmake + +# steps: +# - name: Download artifacts +# id: download-artifact +# uses: actions/download-artifact@v4 +# with: +# path: ./artifact +# pattern: sd-* +# merge-multiple: true + +# - name: Get commit hash +# id: commit +# uses: pr-mpt/actions-commit-hash@v2 + +# - name: Create release +# id: create_release +# uses: anzz1/action-create-release@v1 +# env: +# GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} +# with: +# tag_name: ${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }} + +# - name: Upload release +# id: upload_release +# uses: actions/github-script@v3 +# with: +# github-token: ${{secrets.GITHUB_TOKEN}} +# script: | +# const path = require('path'); +# const fs = require('fs'); +# const release_id = '${{ steps.create_release.outputs.id }}'; +# for (let file of await fs.readdirSync('./artifact')) { +# if (path.extname(file) === '.zip') { +# console.log('uploadReleaseAsset', file); +# await github.repos.uploadReleaseAsset({ +# owner: context.repo.owner, +# repo: context.repo.repo, +# release_id: release_id, +# name: file, +# data: await fs.readFileSync(`./artifact/${file}`) +# }); +# } +# } diff --git a/CMakeLists.txt b/CMakeLists.txt index c993e7c96..06de0d58b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -24,20 +24,21 @@ endif() # general #option(SD_BUILD_TESTS "sd: build tests" ${SD_STANDALONE}) option(SD_BUILD_EXAMPLES "sd: build examples" ${SD_STANDALONE}) -option(SD_CUBLAS "sd: cuda backend" OFF) +option(SD_CUDA "sd: cuda backend" OFF) option(SD_HIPBLAS "sd: rocm backend" OFF) option(SD_METAL "sd: metal backend" OFF) option(SD_VULKAN "sd: vulkan backend" OFF) +option(SD_OPENCL "sd: opencl backend" OFF) option(SD_SYCL "sd: sycl backend" OFF) -option(SD_FLASH_ATTN "sd: use flash attention for x4 less memory usage" OFF) +option(SD_MUSA "sd: musa backend" OFF) option(SD_FAST_SOFTMAX "sd: x1.5 faster softmax, indeterministic (sometimes, same seed don't generate same image), cuda only" OFF) option(SD_BUILD_SHARED_LIBS "sd: build shared libs" OFF) #option(SD_BUILD_SERVER "sd: build server example" ON) -if(SD_CUBLAS) - message("-- Use CUBLAS as backend stable-diffusion") +if(SD_CUDA) + message("-- Use CUDA as backend stable-diffusion") set(GGML_CUDA ON) - add_definitions(-DSD_USE_CUBLAS) + add_definitions(-DSD_USE_CUDA) endif() if(SD_METAL) @@ -52,23 +53,33 @@ if (SD_VULKAN) add_definitions(-DSD_USE_VULKAN) endif () +if (SD_OPENCL) + message("-- Use OpenCL as backend stable-diffusion") + set(GGML_OPENCL ON) + add_definitions(-DSD_USE_OPENCL) +endif () + if (SD_HIPBLAS) message("-- Use HIPBLAS as backend stable-diffusion") - set(GGML_HIPBLAS ON) - add_definitions(-DSD_USE_CUBLAS) + set(GGML_HIP ON) + add_definitions(-DSD_USE_CUDA) if(SD_FAST_SOFTMAX) set(GGML_CUDA_FAST_SOFTMAX ON) endif() endif () -if(SD_FLASH_ATTN) - message("-- Use Flash Attention for memory optimization") - add_definitions(-DSD_USE_FLASH_ATTENTION) +if(SD_MUSA) + message("-- Use MUSA as backend stable-diffusion") + set(GGML_MUSA ON) + add_definitions(-DSD_USE_CUDA) + if(SD_FAST_SOFTMAX) + set(GGML_CUDA_FAST_SOFTMAX ON) + endif() endif() set(SD_LIB stable-diffusion) -file(GLOB SD_LIB_SOURCES +file(GLOB SD_LIB_SOURCES "*.h" "*.cpp" "*.hpp" @@ -92,6 +103,7 @@ endif() if(SD_SYCL) message("-- Use SYCL as backend stable-diffusion") set(GGML_SYCL ON) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wno-narrowing -fsycl") add_definitions(-DSD_USE_SYCL) # disable fast-math on host, see: # https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-10/fp-model-fp.html diff --git a/Dockerfile.musa b/Dockerfile.musa new file mode 100644 index 000000000..c7f5f2e83 --- /dev/null +++ b/Dockerfile.musa @@ -0,0 +1,22 @@ +ARG MUSA_VERSION=rc3.1.1 + +FROM mthreads/musa:${MUSA_VERSION}-devel-ubuntu22.04 as build + +RUN apt-get update && apt-get install -y ccache cmake git + +WORKDIR /sd.cpp + +COPY . . + +RUN mkdir build && cd build && \ + cmake .. -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_FLAGS="${CMAKE_C_FLAGS} -fopenmp -I/usr/lib/llvm-14/lib/clang/14.0.0/include -L/usr/lib/llvm-14/lib" \ + -DCMAKE_CXX_FLAGS="${CMAKE_CXX_FLAGS} -fopenmp -I/usr/lib/llvm-14/lib/clang/14.0.0/include -L/usr/lib/llvm-14/lib" \ + -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release && \ + cmake --build . --config Release + +FROM mthreads/musa:${MUSA_VERSION}-runtime-ubuntu22.04 as runtime + +COPY --from=build /sd.cpp/build/bin/sd /sd + +ENTRYPOINT [ "/sd" ] \ No newline at end of file diff --git a/README.md b/README.md index c1ba396fe..8ce98137f 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,8 @@ Inference of Stable Diffusion and Flux in pure C/C++ - SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors). - [Flux-dev/Flux-schnell Support](./docs/flux.md) - +- [FLUX.1-Kontext-dev](./docs/kontext.md) +- [Chroma](./docs/chroma.md) - [SD-Turbo](https://huggingface.co/stabilityai/sd-turbo) and [SDXL-Turbo](https://huggingface.co/stabilityai/sdxl-turbo) support - [PhotoMaker](https://github.com/TencentARC/PhotoMaker) support. - 16-bit, 32-bit float support @@ -21,10 +22,10 @@ Inference of Stable Diffusion and Flux in pure C/C++ - Accelerated memory-efficient CPU inference - Only requires ~2.3GB when using txt2img with fp16 precision to generate a 512x512 image, enabling Flash Attention just requires ~1.8GB. - AVX, AVX2 and AVX512 support for x86 architectures -- Full CUDA, Metal, Vulkan and SYCL backend for GPU acceleration. +- Full CUDA, Metal, Vulkan, OpenCL and SYCL backend for GPU acceleration. - Can load ckpt, safetensors and diffusers models/checkpoints. Standalone VAEs models - No need to convert to `.ggml` or `.gguf` anymore! -- Flash Attention for memory usage optimization (only cpu for now) +- Flash Attention for memory usage optimization - Original `txt2img` and `img2img` mode - Negative prompt - [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) style tokenizer (not all the features, only token weighting for now) @@ -49,7 +50,7 @@ Inference of Stable Diffusion and Flux in pure C/C++ - Linux - Mac OS - Windows - - Android (via Termux) + - Android (via Termux, [Local Diffusion](https://github.com/rmatif/Local-Diffusion)) ### TODO @@ -113,12 +114,12 @@ cmake .. -DGGML_OPENBLAS=ON cmake --build . --config Release ``` -##### Using CUBLAS +##### Using CUDA This provides BLAS acceleration using the CUDA cores of your Nvidia GPU. Make sure to have the CUDA toolkit installed. You can download it from your Linux distro's package manager (e.g. `apt install nvidia-cuda-toolkit`) or from here: [CUDA Toolkit](https://developer.nvidia.com/cuda-downloads). Recommended to have at least 4 GB of VRAM. ``` -cmake .. -DSD_CUBLAS=ON +cmake .. -DSD_CUDA=ON cmake --build . --config Release ``` @@ -132,6 +133,14 @@ cmake .. -G "Ninja" -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DSD_H cmake --build . --config Release ``` +##### Using MUSA + +This provides BLAS acceleration using the MUSA cores of your Moore Threads GPU. Make sure to have the MUSA toolkit installed. + +```bash +cmake .. -DCMAKE_C_COMPILER=/usr/local/musa/bin/clang -DCMAKE_CXX_COMPILER=/usr/local/musa/bin/clang++ -DSD_MUSA=ON -DCMAKE_BUILD_TYPE=Release +cmake --build . --config Release +``` ##### Using Metal @@ -151,6 +160,73 @@ cmake .. -DSD_VULKAN=ON cmake --build . --config Release ``` +##### Using OpenCL (for Adreno GPU) + +Currently, it supports only Adreno GPUs and is primarily optimized for Q4_0 type + +To build for Windows ARM please refers to [Windows 11 Arm64 +](https://github.com/ggml-org/llama.cpp/blob/master/docs/backend/OPENCL.md#windows-11-arm64) + +Building for Android: + + Android NDK: + Download and install the Android NDK from the [official Android developer site](https://developer.android.com/ndk/downloads). + +Setup OpenCL Dependencies for NDK: + +You need to provide OpenCL headers and the ICD loader library to your NDK sysroot. + +* OpenCL Headers: + ```bash + # In a temporary working directory + git clone https://github.com/KhronosGroup/OpenCL-Headers + cd OpenCL-Headers + # Replace with your actual NDK installation path + # e.g., cp -r CL /path/to/android-ndk-r26c/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include + sudo cp -r CL /toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include + cd .. + ``` + +* OpenCL ICD Loader: + ```bash + # In the same temporary working directory + git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader + cd OpenCL-ICD-Loader + mkdir build_ndk && cd build_ndk + + # Replace in the CMAKE_TOOLCHAIN_FILE and OPENCL_ICD_LOADER_HEADERS_DIR + cmake .. -G Ninja -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_TOOLCHAIN_FILE=/build/cmake/android.toolchain.cmake \ + -DOPENCL_ICD_LOADER_HEADERS_DIR=/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/include \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=24 \ + -DANDROID_STL=c++_shared + + ninja + # Replace + # e.g., cp libOpenCL.so /path/to/android-ndk-r26c/toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android + sudo cp libOpenCL.so /toolchains/llvm/prebuilt/linux-x86_64/sysroot/usr/lib/aarch64-linux-android + cd ../.. + ``` + +Build `stable-diffusion.cpp` for Android with OpenCL: + +```bash +mkdir build-android && cd build-android + +# Replace with your actual NDK installation path +# e.g., -DCMAKE_TOOLCHAIN_FILE=/path/to/android-ndk-r26c/build/cmake/android.toolchain.cmake +cmake .. -G Ninja \ + -DCMAKE_TOOLCHAIN_FILE=/build/cmake/android.toolchain.cmake \ + -DANDROID_ABI=arm64-v8a \ + -DANDROID_PLATFORM=android-28 \ + -DGGML_OPENMP=OFF \ + -DSD_OPENCL=ON + +ninja +``` +*(Note: Don't forget to include `LD_LIBRARY_PATH=/vendor/lib64` in your command line before running the binary)* + ##### Using SYCL Using SYCL makes the computation run on the Intel GPU. Please make sure you have installed the related driver and [Intel® oneAPI Base toolkit](https://www.intel.com/content/www/us/en/developer/tools/oneapi/base-toolkit.html) before start. More details and steps can refer to [llama.cpp SYCL backend](https://github.com/ggerganov/llama.cpp/blob/master/docs/backend/SYCL.md#linux). @@ -182,11 +258,21 @@ Example of text2img by using SYCL backend: ##### Using Flash Attention -Enabling flash attention reduces memory usage by at least 400 MB. At the moment, it is not supported when CUBLAS is enabled because the kernel implementation is missing. +Enabling flash attention for the diffusion model reduces memory usage by varying amounts of MB. +eg.: + - flux 768x768 ~600mb + - SD2 768x768 ~1400mb + +For most backends, it slows things down, but for cuda it generally speeds it up too. +At the moment, it is only supported for some models and some backends (like cpu, cuda/rocm, metal). +Run by adding `--diffusion-fa` to the arguments and watch for: ``` -cmake .. -DSD_FLASH_ATTN=ON -cmake --build . --config Release +[INFO ] stable-diffusion.cpp:312 - Using flash attention in the diffusion model +``` +and the compute buffer shrink in the debug log: +``` +[DEBUG] ggml_extend.hpp:1004 - flux compute buffer size: 650.00 MB(VRAM) ``` ### Run @@ -196,14 +282,14 @@ usage: ./bin/sd [arguments] arguments: -h, --help show this help message and exit - -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img) + -M, --mode [MODE] run mode, one of: [img_gen, convert], default: img_gen -t, --threads N number of threads to use during computation (default: -1) If threads <= 0, then threads will be set to the number of CPU physical cores -m, --model [MODEL] path to full model --diffusion-model path to the standalone diffusion model --clip_l path to the clip-l text encoder - --clip_g path to the clip-l text encoder - --t5xxl path to the the t5xxl text encoder + --clip_g path to the clip-g text encoder + --t5xxl path to the t5xxl text encoder --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) --control-net [CONTROL_PATH] path to control net model @@ -213,22 +299,34 @@ arguments: --normalize-input normalize PHOTOMAKER input id images --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now --upscale-repeats Run the ESRGAN upscaler this many times (default 1) - --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k) + --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K) If not specified, the default is the type of the weight file + --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: "^vae\.=f16,model\.=q8_0") --lora-model-dir [DIR] lora model directory -i, --init-img [IMAGE] path to the input image, required by img2img + --mask [MASK] path to the mask image, required by img2img with mask --control-image [IMAGE] path to image condition, control net + -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) -o, --output OUTPUT path to write result image to (default: ./output.png) -p, --prompt [PROMPT] the prompt to render -n, --negative-prompt PROMPT the negative prompt (default: "") --cfg-scale SCALE unconditional guidance scale: (default: 7.0) + --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale) + --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5) + --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0) + 0 means disabled, a value of 2.5 is nice for sd3.5 medium + --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0) + --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9]) + --skip-layer-start START SLG enabling point: (default: 0.01) + --skip-layer-end END SLG disabling point: (default: 0.2) + SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END]) --strength STRENGTH strength for noising/unnoising (default: 0.75) - --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%) + --style-ratio STYLE-RATIO strength for keeping input identity (default: 20) --control-strength STRENGTH strength to apply Control Net (default: 0.9) 1.0 corresponds to full destruction of information in init image -H, --height H image height, in pixel space (default: 512) -W, --width W image width, in pixel space (default: 512) - --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm} + --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd} sampling method (default: "euler_a") --steps STEPS number of sample steps (default: 20) --rng {std_default, cuda} RNG (default: cuda) @@ -240,9 +338,15 @@ arguments: --vae-tiling process vae in tiles to reduce memory usage --vae-on-cpu keep vae in cpu (for low vram) --clip-on-cpu keep clip in cpu (for low vram) + --diffusion-fa use flash attention in the diffusion model (for low vram) + Might lower quality, since it implies converting k and v to f16. + This might crash if it is not supported by the backend. --control-net-cpu keep controlnet in cpu (for low vram) --canny apply canny preprocessor (edge detection) - --color Colors the logging tags according to level + --color colors the logging tags according to level + --chroma-disable-dit-mask disable dit mask for chroma + --chroma-enable-t5-mask enable t5 mask for chroma + --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma -v, --verbose print extra info ``` @@ -269,7 +373,7 @@ Using formats of different precisions will yield results of varying quality. ``` -./bin/sd --mode img2img -m ../models/sd-v1-4.ckpt -p "cat with blue eyes" -i ./output.png -o ./img2img_output.png --strength 0.4 +./bin/sd -m ../models/sd-v1-4.ckpt -p "cat with blue eyes" -i ./output.png -o ./img2img_output.png --strength 0.4 ```

@@ -290,14 +394,21 @@ Using formats of different precisions will yield results of varying quality. These projects wrap `stable-diffusion.cpp` for easier use in other languages/frameworks. -* Golang: [seasonjs/stable-diffusion](https://github.com/seasonjs/stable-diffusion) +* Golang (non-cgo): [seasonjs/stable-diffusion](https://github.com/seasonjs/stable-diffusion) +* Golang (cgo): [Binozo/GoStableDiffusion](https://github.com/Binozo/GoStableDiffusion) * C#: [DarthAffe/StableDiffusion.NET](https://github.com/DarthAffe/StableDiffusion.NET) +* Python: [william-murray1204/stable-diffusion-cpp-python](https://github.com/william-murray1204/stable-diffusion-cpp-python) +* Rust: [newfla/diffusion-rs](https://github.com/newfla/diffusion-rs) +* Flutter/Dart: [rmatif/Local-Diffusion](https://github.com/rmatif/Local-Diffusion) ## UIs These projects use `stable-diffusion.cpp` as a backend for their image generation. - [Jellybox](https://jellybox.com) +- [Stable Diffusion GUI](https://github.com/fszontagh/sd.cpp.gui.wx) +- [Stable Diffusion CLI-GUI](https://github.com/piallai/stable-diffusion.cpp) +- [Local Diffusion](https://github.com/rmatif/Local-Diffusion) ## Contributors diff --git a/assets/flux/chroma_v40.png b/assets/flux/chroma_v40.png new file mode 100644 index 000000000..4217009dc Binary files /dev/null and b/assets/flux/chroma_v40.png differ diff --git a/assets/flux/kontext1_dev_output.png b/assets/flux/kontext1_dev_output.png new file mode 100644 index 000000000..4fa5e38dd Binary files /dev/null and b/assets/flux/kontext1_dev_output.png differ diff --git a/clip.hpp b/clip.hpp index f9ac631a8..d359f61cd 100644 --- a/clip.hpp +++ b/clip.hpp @@ -343,6 +343,13 @@ class CLIPTokenizer { } } + std::string clean_up_tokenization(std::string& text) { + std::regex pattern(R"( ,)"); + // Replace " ," with "," + std::string result = std::regex_replace(text, pattern, ","); + return result; + } + std::string decode(const std::vector& tokens) { std::string text = ""; for (int t : tokens) { @@ -351,8 +358,12 @@ class CLIPTokenizer { std::u32string ts = decoder[t]; // printf("%d, %s \n", t, utf32_to_utf8(ts).c_str()); std::string s = utf32_to_utf8(ts); - if (s.length() >= 4 && ends_with(s, "")) { - text += " " + s.replace(s.length() - 4, s.length() - 1, ""); + if (s.length() >= 4) { + if (ends_with(s, "")) { + text += s.replace(s.length() - 4, s.length() - 1, "") + " "; + } else { + text += s; + } } else { text += " " + s; } @@ -364,6 +375,7 @@ class CLIPTokenizer { // std::string s((char *)bytes.data()); // std::string s = ""; + text = clean_up_tokenization(text); return trim(text); } @@ -533,9 +545,12 @@ class CLIPEmbeddings : public GGMLBlock { int64_t vocab_size; int64_t num_positions; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, wtype, embed_dim, vocab_size); - params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type token_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "token_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "token_embedding.weight"] : GGML_TYPE_F32; + enum ggml_type position_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end()) ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32; + + params["token_embedding.weight"] = ggml_new_tensor_2d(ctx, token_wtype, embed_dim, vocab_size); + params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions); } public: @@ -579,11 +594,14 @@ class CLIPVisionEmbeddings : public GGMLBlock { int64_t image_size; int64_t num_patches; int64_t num_positions; + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type patch_wtype = GGML_TYPE_F16; // tensor_types.find(prefix + "patch_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "patch_embedding.weight"] : GGML_TYPE_F16; + enum ggml_type class_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "class_embedding") != tensor_types.end() ? tensor_types[prefix + "class_embedding"] : GGML_TYPE_F32; + enum ggml_type position_wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "position_embedding.weight") != tensor_types.end() ? tensor_types[prefix + "position_embedding.weight"] : GGML_TYPE_F32; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, patch_size, patch_size, num_channels, embed_dim); - params["class_embedding"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, embed_dim); - params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, embed_dim, num_positions); + params["patch_embedding.weight"] = ggml_new_tensor_4d(ctx, patch_wtype, patch_size, patch_size, num_channels, embed_dim); + params["class_embedding"] = ggml_new_tensor_1d(ctx, class_wtype, embed_dim); + params["position_embedding.weight"] = ggml_new_tensor_2d(ctx, position_wtype, embed_dim, num_positions); } public: @@ -639,9 +657,10 @@ enum CLIPVersion { class CLIPTextModel : public GGMLBlock { protected: - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { if (version == OPEN_CLIP_VIT_BIGG_14) { - params["text_projection"] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size); + enum ggml_type wtype = GGML_TYPE_F32; // tensor_types.find(prefix + "text_projection") != tensor_types.end() ? tensor_types[prefix + "text_projection"] : GGML_TYPE_F32; + params["text_projection"] = ggml_new_tensor_2d(ctx, wtype, projection_dim, hidden_size); } } @@ -659,8 +678,8 @@ class CLIPTextModel : public GGMLBlock { bool with_final_ln = true; CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, - int clip_skip_value = -1, - bool with_final_ln = true) + bool with_final_ln = true, + int clip_skip_value = -1) : version(version), with_final_ln(with_final_ln) { if (version == OPEN_CLIP_VIT_H_14) { hidden_size = 1024; @@ -682,7 +701,7 @@ class CLIPTextModel : public GGMLBlock { void set_clip_skip(int skip) { if (skip <= 0) { - return; + skip = -1; } clip_skip = skip; } @@ -711,8 +730,12 @@ class CLIPTextModel : public GGMLBlock { if (return_pooled) { auto text_projection = params["text_projection"]; ggml_tensor* pooled = ggml_view_1d(ctx, x, hidden_size, x->nb[1] * max_token_idx); - pooled = ggml_mul_mat(ctx, ggml_cont(ctx, ggml_transpose(ctx, text_projection)), pooled); - return pooled; + if (text_projection != NULL) { + pooled = ggml_nn_linear(ctx, pooled, text_projection, NULL); + } else { + LOG_DEBUG("Missing text_projection matrix, assuming identity..."); + } + return pooled; // [hidden_size, 1, 1] } return x; // [N, n_token, hidden_size] @@ -761,14 +784,17 @@ class CLIPVisionModel : public GGMLBlock { auto x = embeddings->forward(ctx, pixel_values); // [N, num_positions, embed_dim] x = pre_layernorm->forward(ctx, x); x = encoder->forward(ctx, x, -1, false); - x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] + // print_ggml_tensor(x, true, "ClipVisionModel x: "); + auto last_hidden_state = x; + x = post_layernorm->forward(ctx, x); // [N, n_token, hidden_size] GGML_ASSERT(x->ne[3] == 1); if (return_pooled) { ggml_tensor* pooled = ggml_cont(ctx, ggml_view_2d(ctx, x, x->ne[0], x->ne[2], x->nb[2], 0)); return pooled; // [N, hidden_size] } else { - return x; // [N, n_token, hidden_size] + // return x; // [N, n_token, hidden_size] + return last_hidden_state; // [N, n_token, hidden_size] } } }; @@ -779,9 +805,9 @@ class CLIPProjection : public UnaryBlock { int64_t out_features; bool transpose_weight; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = tensor_types.find(prefix + "weight") != tensor_types.end() ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; if (transpose_weight) { - LOG_ERROR("transpose_weight"); params["weight"] = ggml_new_tensor_2d(ctx, wtype, out_features, in_features); } else { params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); @@ -842,12 +868,13 @@ struct CLIPTextModelRunner : public GGMLRunner { CLIPTextModel model; CLIPTextModelRunner(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, + const std::string prefix, CLIPVersion version = OPENAI_CLIP_VIT_L_14, - int clip_skip_value = 1, - bool with_final_ln = true) - : GGMLRunner(backend, wtype), model(version, clip_skip_value, with_final_ln) { - model.init(params_ctx, wtype); + bool with_final_ln = true, + int clip_skip_value = -1) + : GGMLRunner(backend), model(version, with_final_ln, clip_skip_value) { + model.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -889,13 +916,13 @@ struct CLIPTextModelRunner : public GGMLRunner { struct ggml_tensor* embeddings = NULL; if (num_custom_embeddings > 0 && custom_embeddings_data != NULL) { - auto custom_embeddings = ggml_new_tensor_2d(compute_ctx, - wtype, - model.hidden_size, - num_custom_embeddings); + auto token_embed_weight = model.get_token_embed_weight(); + auto custom_embeddings = ggml_new_tensor_2d(compute_ctx, + token_embed_weight->type, + model.hidden_size, + num_custom_embeddings); set_backend_tensor_data(custom_embeddings, custom_embeddings_data); - auto token_embed_weight = model.get_token_embed_weight(); // concatenate custom embeddings embeddings = ggml_concat(compute_ctx, token_embed_weight, custom_embeddings, 1); } diff --git a/common.hpp b/common.hpp index b18ee51f5..9b5cc53be 100644 --- a/common.hpp +++ b/common.hpp @@ -56,8 +56,8 @@ class UpSampleBlock : public GGMLBlock { // x: [N, channels, h, w] auto conv = std::dynamic_pointer_cast(blocks["conv"]); - x = ggml_upscale(ctx, x, 2); // [N, channels, h*2, w*2] - x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] + x = ggml_upscale(ctx, x, 2, GGML_SCALE_MODE_NEAREST); // [N, channels, h*2, w*2] + x = conv->forward(ctx, x); // [N, out_channels, h*2, w*2] return x; } }; @@ -182,9 +182,11 @@ class GEGLU : public GGMLBlock { int64_t dim_in; int64_t dim_out; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2); - params["proj.bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, dim_out * 2); + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "proj.weight") != tensor_types.end()) ? tensor_types[prefix + "proj.weight"] : GGML_TYPE_F32; + enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "proj.bias") != tensor_types.end()) ? tensor_types[prefix + "proj.bias"] : GGML_TYPE_F32; + params["proj.weight"] = ggml_new_tensor_2d(ctx, wtype, dim_in, dim_out * 2); + params["proj.bias"] = ggml_new_tensor_1d(ctx, bias_wtype, dim_out * 2); } public: @@ -245,16 +247,19 @@ class CrossAttention : public GGMLBlock { int64_t context_dim; int64_t n_head; int64_t d_head; + bool flash_attn; public: CrossAttention(int64_t query_dim, int64_t context_dim, int64_t n_head, - int64_t d_head) + int64_t d_head, + bool flash_attn = false) : n_head(n_head), d_head(d_head), query_dim(query_dim), - context_dim(context_dim) { + context_dim(context_dim), + flash_attn(flash_attn) { int64_t inner_dim = d_head * n_head; blocks["to_q"] = std::shared_ptr(new Linear(query_dim, inner_dim, false)); @@ -283,7 +288,7 @@ class CrossAttention : public GGMLBlock { auto k = to_k->forward(ctx, context); // [N, n_context, inner_dim] auto v = to_v->forward(ctx, context); // [N, n_context, inner_dim] - x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false); // [N, n_token, inner_dim] + x = ggml_nn_attention_ext(ctx, q, k, v, n_head, NULL, false, false, flash_attn); // [N, n_token, inner_dim] x = to_out_0->forward(ctx, x); // [N, n_token, query_dim] return x; @@ -301,15 +306,16 @@ class BasicTransformerBlock : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t context_dim, - bool ff_in = false) + bool ff_in = false, + bool flash_attn = false) : n_head(n_head), d_head(d_head), ff_in(ff_in) { // disable_self_attn is always False // disable_temporal_crossattention is always False // switch_temporal_ca_to_sa is always False // inner_dim is always None or equal to dim // gated_ff is always True - blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head)); - blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head)); + blocks["attn1"] = std::shared_ptr(new CrossAttention(dim, dim, n_head, d_head, flash_attn)); + blocks["attn2"] = std::shared_ptr(new CrossAttention(dim, context_dim, n_head, d_head, flash_attn)); blocks["ff"] = std::shared_ptr(new FeedForward(dim, dim)); blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); @@ -374,7 +380,8 @@ class SpatialTransformer : public GGMLBlock { int64_t n_head, int64_t d_head, int64_t depth, - int64_t context_dim) + int64_t context_dim, + bool flash_attn = false) : in_channels(in_channels), n_head(n_head), d_head(d_head), @@ -388,7 +395,7 @@ class SpatialTransformer : public GGMLBlock { for (int i = 0; i < depth; i++) { std::string name = "transformer_blocks." + std::to_string(i); - blocks[name] = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim)); + blocks[name] = std::shared_ptr(new BasicTransformerBlock(inner_dim, n_head, d_head, context_dim, false, flash_attn)); } blocks["proj_out"] = std::shared_ptr(new Conv2d(inner_dim, in_channels, {1, 1})); @@ -433,8 +440,10 @@ class SpatialTransformer : public GGMLBlock { class AlphaBlender : public GGMLBlock { protected: - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + // Get the type of the "mix_factor" tensor from the input tensors map with the specified prefix + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32; + params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); } float get_alpha() { @@ -511,4 +520,4 @@ class VideoResBlock : public ResBlock { } }; -#endif // __COMMON_HPP__ \ No newline at end of file +#endif // __COMMON_HPP__ diff --git a/conditioner.hpp b/conditioner.hpp index ac2ab7ebf..3f89d5263 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -43,71 +43,74 @@ struct Conditioner { // ldm.modules.encoders.modules.FrozenCLIPEmbedder // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/sd_hijack_clip.py#L283 struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { - SDVersion version = VERSION_SD1; + SDVersion version = VERSION_SD1; + PMVersion pm_version = PM_VERSION_1; CLIPTokenizer tokenizer; - ggml_type wtype; std::shared_ptr text_model; std::shared_ptr text_model2; std::string trigger_word = "img"; // should be user settable std::string embd_dir; - int32_t num_custom_embeddings = 0; + int32_t num_custom_embeddings = 0; + int32_t num_custom_embeddings_2 = 0; std::vector token_embed_custom; std::vector readed_embeddings; FrozenCLIPEmbedderWithCustomWords(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, const std::string& embd_dir, SDVersion version = VERSION_SD1, + PMVersion pv = PM_VERSION_1, int clip_skip = -1) - : version(version), tokenizer(version == VERSION_SD2 ? 0 : 49407), embd_dir(embd_dir), wtype(wtype) { + : version(version), pm_version(pv), tokenizer(sd_version_is_sd2(version) ? 0 : 49407), embd_dir(embd_dir) { + if (sd_version_is_sd1(version)) { + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14); + } else if (sd_version_is_sd2(version)) { + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPEN_CLIP_VIT_H_14); + } else if (sd_version_is_sdxl(version)) { + text_model = std::make_shared(backend, tensor_types, "cond_stage_model.transformer.text_model", OPENAI_CLIP_VIT_L_14, false); + text_model2 = std::make_shared(backend, tensor_types, "cond_stage_model.1.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false); + } + set_clip_skip(clip_skip); + } + + void set_clip_skip(int clip_skip) { if (clip_skip <= 0) { clip_skip = 1; - if (version == VERSION_SD2 || version == VERSION_SDXL) { + if (sd_version_is_sd2(version) || sd_version_is_sdxl(version)) { clip_skip = 2; } } - if (version == VERSION_SD1) { - text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip); - } else if (version == VERSION_SD2) { - text_model = std::make_shared(backend, wtype, OPEN_CLIP_VIT_H_14, clip_skip); - } else if (version == VERSION_SDXL) { - text_model = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false); - text_model2 = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false); - } - } - - void set_clip_skip(int clip_skip) { text_model->set_clip_skip(clip_skip); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->set_clip_skip(clip_skip); } } void get_param_tensors(std::map& tensors) { text_model->get_param_tensors(tensors, "cond_stage_model.transformer.text_model"); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->get_param_tensors(tensors, "cond_stage_model.1.transformer.text_model"); } } void alloc_params_buffer() { text_model->alloc_params_buffer(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->alloc_params_buffer(); } } void free_params_buffer() { text_model->free_params_buffer(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->free_params_buffer(); } } size_t get_params_buffer_size() { size_t buffer_size = text_model->get_params_buffer_size(); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { buffer_size += text_model2->get_params_buffer_size(); } return buffer_size; @@ -130,28 +133,55 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { params.no_alloc = false; struct ggml_context* embd_ctx = ggml_init(params); struct ggml_tensor* embd = NULL; - int64_t hidden_size = text_model->model.hidden_size; + struct ggml_tensor* embd2 = NULL; auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) { - if (tensor_storage.ne[0] != hidden_size) { - LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], hidden_size); - return false; + if (tensor_storage.ne[0] != text_model->model.hidden_size) { + if (text_model2) { + if (tensor_storage.ne[0] == text_model2->model.hidden_size) { + embd2 = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model2->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); + *dst_tensor = embd2; + } else { + LOG_DEBUG("embedding wrong hidden size, got %i, expected %i or %i", tensor_storage.ne[0], text_model->model.hidden_size, text_model2->model.hidden_size); + return false; + } + } else { + LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", tensor_storage.ne[0], text_model->model.hidden_size); + return false; + } + } else { + embd = ggml_new_tensor_2d(embd_ctx, tensor_storage.type, text_model->model.hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); + *dst_tensor = embd; } - embd = ggml_new_tensor_2d(embd_ctx, wtype, hidden_size, tensor_storage.n_dims > 1 ? tensor_storage.ne[1] : 1); - *dst_tensor = embd; return true; }; model_loader.load_tensors(on_load, NULL); readed_embeddings.push_back(embd_name); - token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd)); - memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(wtype)), - embd->data, - ggml_nbytes(embd)); - for (int i = 0; i < embd->ne[1]; i++) { - bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings); - // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings); - num_custom_embeddings++; + if (embd) { + int64_t hidden_size = text_model->model.hidden_size; + token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd)); + memcpy((void*)(token_embed_custom.data() + num_custom_embeddings * hidden_size * ggml_type_size(embd->type)), + embd->data, + ggml_nbytes(embd)); + for (int i = 0; i < embd->ne[1]; i++) { + bpe_tokens.push_back(text_model->model.vocab_size + num_custom_embeddings); + // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings); + num_custom_embeddings++; + } + LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings); + } + if (embd2) { + int64_t hidden_size = text_model2->model.hidden_size; + token_embed_custom.resize(token_embed_custom.size() + ggml_nbytes(embd2)); + memcpy((void*)(token_embed_custom.data() + num_custom_embeddings_2 * hidden_size * ggml_type_size(embd2->type)), + embd2->data, + ggml_nbytes(embd2)); + for (int i = 0; i < embd2->ne[1]; i++) { + bpe_tokens.push_back(text_model2->model.vocab_size + num_custom_embeddings_2); + // LOG_DEBUG("new custom token: %i", text_model.vocab_size + num_custom_embeddings); + num_custom_embeddings_2++; + } + LOG_DEBUG("embedding '%s' applied, custom embeddings: %i (text model 2)", embd_name.c_str(), num_custom_embeddings_2); } - LOG_DEBUG("embedding '%s' applied, custom embeddings: %i", embd_name.c_str(), num_custom_embeddings); return true; } @@ -268,7 +298,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { std::vector clean_input_ids_tmp; for (uint32_t i = 0; i < class_token_index[0]; i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); - for (uint32_t i = 0; i < num_input_imgs; i++) + for (uint32_t i = 0; i < (pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs); i++) clean_input_ids_tmp.push_back(class_token); for (uint32_t i = class_token_index[0] + 1; i < clean_input_ids.size(); i++) clean_input_ids_tmp.push_back(clean_input_ids[i]); @@ -279,13 +309,16 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { tokens.insert(tokens.end(), clean_input_ids.begin(), clean_input_ids.end()); weights.insert(weights.end(), clean_input_ids.size(), curr_weight); } - tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); - weights.insert(weights.begin(), 1.0); + // BUG!! double couting, pad_tokens will add BOS at the beginning + // tokens.insert(tokens.begin(), tokenizer.BOS_TOKEN_ID); + // weights.insert(weights.begin(), 1.0); tokenizer.pad_tokens(tokens, weights, max_length, padding); - + int offset = pm_version == PM_VERSION_2 ? 2 * num_input_imgs : num_input_imgs; for (uint32_t i = 0; i < tokens.size(); i++) { - if (class_idx + 1 <= i && i < class_idx + 1 + num_input_imgs) + // if (class_idx + 1 <= i && i < class_idx + 1 + 2*num_input_imgs) // photomaker V2 has num_tokens(=2)*num_input_imgs + if (class_idx + 1 <= i && i < class_idx + 1 + offset) // photomaker V2 has num_tokens(=2)*num_input_imgs + // hardcode for now class_token_mask.push_back(true); else class_token_mask.push_back(false); @@ -398,7 +431,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); struct ggml_tensor* input_ids2 = NULL; size_t max_token_idx = 0; - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), tokenizer.EOS_TOKEN_ID); if (it != chunk_tokens.end()) { std::fill(std::next(it), chunk_tokens.end(), 0); @@ -423,11 +456,11 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { false, &chunk_hidden_states1, work_ctx); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { text_model2->compute(n_threads, input_ids2, - 0, - NULL, + num_custom_embeddings, + token_embed_custom.data(), max_token_idx, false, &chunk_hidden_states2, work_ctx); @@ -437,8 +470,8 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { if (chunk_idx == 0) { text_model2->compute(n_threads, input_ids2, - 0, - NULL, + num_custom_embeddings, + token_embed_custom.data(), max_token_idx, true, &pooled, @@ -482,7 +515,7 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); ggml_tensor* vec = NULL; - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { int out_dim = 256; vec = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, adm_in_channels); // [0:1280] @@ -585,9 +618,9 @@ struct FrozenCLIPEmbedderWithCustomWords : public Conditioner { struct FrozenCLIPVisionEmbedder : public GGMLRunner { CLIPVisionModelProjection vision_model; - FrozenCLIPVisionEmbedder(ggml_backend_t backend, ggml_type wtype) - : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend, wtype) { - vision_model.init(params_ctx, wtype); + FrozenCLIPVisionEmbedder(ggml_backend_t backend, std::map& tensor_types) + : vision_model(OPEN_CLIP_VIT_H_14, true), GGMLRunner(backend) { + vision_model.init(params_ctx, tensor_types, "cond_stage_model.transformer"); } std::string get_desc() { @@ -622,7 +655,6 @@ struct FrozenCLIPVisionEmbedder : public GGMLRunner { }; struct SD3CLIPEmbedder : public Conditioner { - ggml_type wtype; CLIPTokenizer clip_l_tokenizer; CLIPTokenizer clip_g_tokenizer; T5UniGramTokenizer t5_tokenizer; @@ -631,18 +663,19 @@ struct SD3CLIPEmbedder : public Conditioner { std::shared_ptr t5; SD3CLIPEmbedder(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, int clip_skip = -1) - : wtype(wtype), clip_g_tokenizer(0) { - if (clip_skip <= 0) { - clip_skip = 2; - } - clip_l = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, false); - clip_g = std::make_shared(backend, wtype, OPEN_CLIP_VIT_BIGG_14, clip_skip, false); - t5 = std::make_shared(backend, wtype); + : clip_g_tokenizer(0) { + clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, false); + clip_g = std::make_shared(backend, tensor_types, "text_encoders.clip_g.transformer.text_model", OPEN_CLIP_VIT_BIGG_14, false); + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + set_clip_skip(clip_skip); } void set_clip_skip(int clip_skip) { + if (clip_skip <= 0) { + clip_skip = 2; + } clip_l->set_clip_skip(clip_skip); clip_g->set_clip_skip(clip_skip); } @@ -716,7 +749,7 @@ struct SD3CLIPEmbedder : public Conditioner { clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, max_length, padding); clip_g_tokenizer.pad_tokens(clip_g_tokens, clip_g_weights, max_length, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -798,21 +831,16 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_l->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled_l, - // work_ctx); - - // clip_l.transformer.text_model.text_projection no in file, ignore - // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection - pooled_l = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); - ggml_set_f32(pooled_l, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + clip_l->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled_l, + work_ctx); } } @@ -852,21 +880,16 @@ struct SD3CLIPEmbedder : public Conditioner { } if (chunk_idx == 0) { - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_g->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled_g, - // work_ctx); - // clip_l.transformer.text_model.text_projection no in file, ignore pooled_g too - - // TODO: fix pooled_g - pooled_g = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1280); - ggml_set_f32(pooled_g, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_g_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + clip_g->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled_g, + work_ctx); } } @@ -881,6 +904,7 @@ struct SD3CLIPEmbedder : public Conditioner { t5->compute(n_threads, input_ids, + NULL, &chunk_hidden_states_t5, work_ctx); { @@ -979,24 +1003,24 @@ struct SD3CLIPEmbedder : public Conditioner { }; struct FluxCLIPEmbedder : public Conditioner { - ggml_type wtype; CLIPTokenizer clip_l_tokenizer; T5UniGramTokenizer t5_tokenizer; std::shared_ptr clip_l; std::shared_ptr t5; + size_t chunk_len = 256; FluxCLIPEmbedder(ggml_backend_t backend, - ggml_type wtype, - int clip_skip = -1) - : wtype(wtype) { - if (clip_skip <= 0) { - clip_skip = 2; - } - clip_l = std::make_shared(backend, wtype, OPENAI_CLIP_VIT_L_14, clip_skip, true); - t5 = std::make_shared(backend, wtype); + std::map& tensor_types, + int clip_skip = -1) { + clip_l = std::make_shared(backend, tensor_types, "text_encoders.clip_l.transformer.text_model", OPENAI_CLIP_VIT_L_14, true); + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + set_clip_skip(clip_skip); } void set_clip_skip(int clip_skip) { + if (clip_skip <= 0) { + clip_skip = 2; + } clip_l->set_clip_skip(clip_skip); } @@ -1058,7 +1082,7 @@ struct FluxCLIPEmbedder : public Conditioner { } clip_l_tokenizer.pad_tokens(clip_l_tokens, clip_l_weights, 77, padding); - t5_tokenizer.pad_tokens(t5_tokens, t5_weights, max_length, padding); + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, NULL, max_length, padding); // for (int i = 0; i < clip_l_tokens.size(); i++) { // std::cout << clip_l_tokens[i] << ":" << clip_l_weights[i] << ", "; @@ -1090,7 +1114,6 @@ struct FluxCLIPEmbedder : public Conditioner { struct ggml_tensor* pooled = NULL; // [768,] std::vector hidden_states_vec; - size_t chunk_len = 256; size_t chunk_count = t5_tokens.size() / chunk_len; for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { // clip_l @@ -1104,21 +1127,17 @@ struct FluxCLIPEmbedder : public Conditioner { auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); size_t max_token_idx = 0; - // auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); - // max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); - // clip_l->compute(n_threads, - // input_ids, - // 0, - // NULL, - // max_token_idx, - // true, - // &pooled, - // work_ctx); - - // clip_l.transformer.text_model.text_projection no in file, ignore - // TODO: use torch.eye(embed_dim) as default clip_l.transformer.text_model.text_projection - pooled = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 768); - ggml_set_f32(pooled, 0.f); + auto it = std::find(chunk_tokens.begin(), chunk_tokens.end(), clip_l_tokenizer.EOS_TOKEN_ID); + max_token_idx = std::min(std::distance(chunk_tokens.begin(), it), chunk_tokens.size() - 1); + + clip_l->compute(n_threads, + input_ids, + 0, + NULL, + max_token_idx, + true, + &pooled, + work_ctx); } // t5 @@ -1132,6 +1151,7 @@ struct FluxCLIPEmbedder : public Conditioner { t5->compute(n_threads, input_ids, + NULL, &chunk_hidden_states, work_ctx); { @@ -1181,7 +1201,209 @@ struct FluxCLIPEmbedder : public Conditioner { int height, int adm_in_channels = -1, bool force_zero_embeddings = false) { - auto tokens_and_weights = tokenize(text, 256, true); + auto tokens_and_weights = tokenize(text, chunk_len, true); + return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); + } + + std::tuple> get_learned_condition_with_trigger(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int num_input_imgs, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + GGML_ASSERT(0 && "Not implemented yet!"); + } + + std::string remove_trigger_from_prompt(ggml_context* work_ctx, + const std::string& prompt) { + GGML_ASSERT(0 && "Not implemented yet!"); + } +}; + +struct PixArtCLIPEmbedder : public Conditioner { + T5UniGramTokenizer t5_tokenizer; + std::shared_ptr t5; + size_t chunk_len = 512; + bool use_mask = false; + int mask_pad = 1; + + PixArtCLIPEmbedder(ggml_backend_t backend, + std::map& tensor_types, + int clip_skip = -1, + bool use_mask = false, + int mask_pad = 1) + : use_mask(use_mask), mask_pad(mask_pad) { + t5 = std::make_shared(backend, tensor_types, "text_encoders.t5xxl.transformer"); + } + + void set_clip_skip(int clip_skip) { + } + + void get_param_tensors(std::map& tensors) { + t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); + } + + void alloc_params_buffer() { + t5->alloc_params_buffer(); + } + + void free_params_buffer() { + t5->free_params_buffer(); + } + + size_t get_params_buffer_size() { + size_t buffer_size = 0; + + buffer_size += t5->get_params_buffer_size(); + + return buffer_size; + } + + std::tuple, std::vector, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { + auto parsed_attention = parse_prompt_attention(text); + + { + std::stringstream ss; + ss << "["; + for (const auto& item : parsed_attention) { + ss << "['" << item.first << "', " << item.second << "], "; + } + ss << "]"; + LOG_DEBUG("parse '%s' to %s", text.c_str(), ss.str().c_str()); + } + + auto on_new_token_cb = [&](std::string& str, std::vector& bpe_tokens) -> bool { + return false; + }; + + std::vector t5_tokens; + std::vector t5_weights; + std::vector t5_mask; + for (const auto& item : parsed_attention) { + const std::string& curr_text = item.first; + float curr_weight = item.second; + + std::vector curr_tokens = t5_tokenizer.Encode(curr_text, true); + t5_tokens.insert(t5_tokens.end(), curr_tokens.begin(), curr_tokens.end()); + t5_weights.insert(t5_weights.end(), curr_tokens.size(), curr_weight); + } + + t5_tokenizer.pad_tokens(t5_tokens, t5_weights, &t5_mask, max_length, padding); + + return {t5_tokens, t5_weights, t5_mask}; + } + + void modify_mask_to_attend_padding(struct ggml_tensor* mask, int max_seq_length, int num_extra_padding = 8) { + float* mask_data = (float*)mask->data; + int num_pad = 0; + for (int64_t i = 0; i < max_seq_length; i++) { + if (num_pad >= num_extra_padding) { + break; + } + if (std::isinf(mask_data[i])) { + mask_data[i] = 0; + ++num_pad; + } + } + // LOG_DEBUG("PAD: %d", num_pad); + } + + SDCondition get_learned_condition_common(ggml_context* work_ctx, + int n_threads, + std::tuple, std::vector, std::vector> token_and_weights, + int clip_skip, + bool force_zero_embeddings = false) { + auto& t5_tokens = std::get<0>(token_and_weights); + auto& t5_weights = std::get<1>(token_and_weights); + auto& t5_attn_mask_vec = std::get<2>(token_and_weights); + + int64_t t0 = ggml_time_ms(); + struct ggml_tensor* hidden_states = NULL; // [N, n_token, 4096] + struct ggml_tensor* chunk_hidden_states = NULL; // [n_token, 4096] + struct ggml_tensor* pooled = NULL; // [768,] + struct ggml_tensor* t5_attn_mask = vector_to_ggml_tensor(work_ctx, t5_attn_mask_vec); // [768,] + + std::vector hidden_states_vec; + + size_t chunk_count = t5_tokens.size() / chunk_len; + + for (int chunk_idx = 0; chunk_idx < chunk_count; chunk_idx++) { + // t5 + std::vector chunk_tokens(t5_tokens.begin() + chunk_idx * chunk_len, + t5_tokens.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_weights(t5_weights.begin() + chunk_idx * chunk_len, + t5_weights.begin() + (chunk_idx + 1) * chunk_len); + std::vector chunk_mask(t5_attn_mask_vec.begin() + chunk_idx * chunk_len, + t5_attn_mask_vec.begin() + (chunk_idx + 1) * chunk_len); + + auto input_ids = vector_to_ggml_tensor_i32(work_ctx, chunk_tokens); + auto t5_attn_mask_chunk = use_mask ? vector_to_ggml_tensor(work_ctx, chunk_mask) : NULL; + + t5->compute(n_threads, + input_ids, + t5_attn_mask_chunk, + &chunk_hidden_states, + work_ctx); + { + auto tensor = chunk_hidden_states; + float original_mean = ggml_tensor_mean(tensor); + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + float value = ggml_tensor_get_f32(tensor, i0, i1, i2); + value *= chunk_weights[i1]; + ggml_tensor_set_f32(tensor, value, i0, i1, i2); + } + } + } + float new_mean = ggml_tensor_mean(tensor); + ggml_tensor_scale(tensor, (original_mean / new_mean)); + } + + int64_t t1 = ggml_time_ms(); + LOG_DEBUG("computing condition graph completed, taking %" PRId64 " ms", t1 - t0); + if (force_zero_embeddings) { + float* vec = (float*)chunk_hidden_states->data; + for (int i = 0; i < ggml_nelements(chunk_hidden_states); i++) { + vec[i] = 0; + } + } + + hidden_states_vec.insert(hidden_states_vec.end(), + (float*)chunk_hidden_states->data, + ((float*)chunk_hidden_states->data) + ggml_nelements(chunk_hidden_states)); + } + + if (hidden_states_vec.size() > 0) { + hidden_states = vector_to_ggml_tensor(work_ctx, hidden_states_vec); + hidden_states = ggml_reshape_2d(work_ctx, + hidden_states, + chunk_hidden_states->ne[0], + ggml_nelements(hidden_states) / chunk_hidden_states->ne[0]); + } else { + hidden_states = ggml_new_tensor_2d(work_ctx, GGML_TYPE_F32, 4096, 256); + ggml_set_f32(hidden_states, 0.f); + } + + modify_mask_to_attend_padding(t5_attn_mask, ggml_nelements(t5_attn_mask), mask_pad); + + return SDCondition(hidden_states, t5_attn_mask, NULL); + } + + SDCondition get_learned_condition(ggml_context* work_ctx, + int n_threads, + const std::string& text, + int clip_skip, + int width, + int height, + int adm_in_channels = -1, + bool force_zero_embeddings = false) { + auto tokens_and_weights = tokenize(text, chunk_len, true); return get_learned_condition_common(work_ctx, n_threads, tokens_and_weights, clip_skip, force_zero_embeddings); } @@ -1203,4 +1425,4 @@ struct FluxCLIPEmbedder : public Conditioner { } }; -#endif \ No newline at end of file +#endif diff --git a/control.hpp b/control.hpp index 41f31acb7..23b75feff 100644 --- a/control.hpp +++ b/control.hpp @@ -34,11 +34,11 @@ class ControlNetBlock : public GGMLBlock { ControlNetBlock(SDVersion version = VERSION_SD1) : version(version) { - if (version == VERSION_SD2) { + if (sd_version_is_sd2(version)) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_SDXL) { + } else if (sd_version_is_sdxl(version)) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -58,7 +58,7 @@ class ControlNetBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_SDXL || version == VERSION_SVD) { + if (sd_version_is_sdxl(version) || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -317,10 +317,10 @@ struct ControlNet : public GGMLRunner { bool guided_hint_cached = false; ControlNet(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, SDVersion version = VERSION_SD1) - : GGMLRunner(backend, wtype), control_net(version) { - control_net.init(params_ctx, wtype); + : GGMLRunner(backend), control_net(version) { + control_net.init(params_ctx, tensor_types, ""); } ~ControlNet() { diff --git a/denoiser.hpp b/denoiser.hpp index 975699d22..d4bcec590 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -168,24 +168,21 @@ struct AYSSchedule : SigmaSchedule { std::vector inputs; std::vector results(n + 1); - switch (version) { - case VERSION_SD2: /* fallthrough */ - LOG_WARN("AYS not designed for SD2.X models"); - case VERSION_SD1: - LOG_INFO("AYS using SD1.5 noise levels"); - inputs = noise_levels[0]; - break; - case VERSION_SDXL: - LOG_INFO("AYS using SDXL noise levels"); - inputs = noise_levels[1]; - break; - case VERSION_SVD: - LOG_INFO("AYS using SVD noise levels"); - inputs = noise_levels[2]; - break; - default: - LOG_ERROR("Version not compatable with AYS scheduler"); - return results; + if (sd_version_is_sd2((SDVersion)version)) { + LOG_WARN("AYS not designed for SD2.X models"); + } /* fallthrough */ + else if (sd_version_is_sd1((SDVersion)version)) { + LOG_INFO("AYS using SD1.5 noise levels"); + inputs = noise_levels[0]; + } else if (sd_version_is_sdxl((SDVersion)version)) { + LOG_INFO("AYS using SDXL noise levels"); + inputs = noise_levels[1]; + } else if (version == VERSION_SVD) { + LOG_INFO("AYS using SVD noise levels"); + inputs = noise_levels[2]; + } else { + LOG_ERROR("Version not compatible with AYS scheduler"); + return results; } /* Stretches those pre-calculated reference levels out to the desired @@ -346,6 +343,32 @@ struct CompVisVDenoiser : public CompVisDenoiser { } }; +struct EDMVDenoiser : public CompVisVDenoiser { + float min_sigma = 0.002; + float max_sigma = 120.0; + + EDMVDenoiser(float min_sigma = 0.002, float max_sigma = 120.0) + : min_sigma(min_sigma), max_sigma(max_sigma) { + schedule = std::make_shared(); + } + + float t_to_sigma(float t) { + return std::exp(t * 4 / (float)TIMESTEPS); + } + + float sigma_to_t(float s) { + return 0.25 * std::log(s); + } + + float sigma_min() { + return min_sigma; + } + + float sigma_max() { + return max_sigma; + } +}; + float time_snr_shift(float alpha, float t) { if (alpha == 1.0f) { return t; @@ -474,7 +497,8 @@ static void sample_k_diffusion(sample_method_t method, ggml_context* work_ctx, ggml_tensor* x, std::vector sigmas, - std::shared_ptr rng) { + std::shared_ptr rng, + float eta) { size_t steps = sigmas.size() - 1; // sample_euler_ancestral switch (method) { @@ -1005,6 +1029,370 @@ static void sample_k_diffusion(sample_method_t method, } } } break; + case DDIM_TRAILING: // Denoising Diffusion Implicit Models + // with the "trailing" timestep spacing + { + // See J. Song et al., "Denoising Diffusion Implicit + // Models", arXiv:2010.02502 [cs.LG] + // + // DDIM itself needs alphas_cumprod (DDPM, J. Ho et al., + // arXiv:2006.11239 [cs.LG] with k-diffusion's start and + // end beta) (which unfortunately k-diffusion's data + // structure hides from the denoiser), and the sigmas are + // also needed to invert the behavior of CompVisDenoiser + // (k-diffusion's LMSDiscreteScheduler) + float beta_start = 0.00085f; + float beta_end = 0.0120f; + std::vector alphas_cumprod; + std::vector compvis_sigmas; + + alphas_cumprod.reserve(TIMESTEPS); + compvis_sigmas.reserve(TIMESTEPS); + for (int i = 0; i < TIMESTEPS; i++) { + alphas_cumprod[i] = + (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * + (1.0f - + std::pow(sqrtf(beta_start) + + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); + compvis_sigmas[i] = + std::sqrt((1 - alphas_cumprod[i]) / + alphas_cumprod[i]); + } + + struct ggml_tensor* pred_original_sample = + ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* variance_noise = + ggml_dup_tensor(work_ctx, x); + + for (int i = 0; i < steps; i++) { + // The "trailing" DDIM timestep, see S. Lin et al., + // "Common Diffusion Noise Schedules and Sample Steps + // are Flawed", arXiv:2305.08891 [cs], p. 4, Table + // 2. Most variables below follow Diffusers naming + // + // Diffuser naming vs. Song et al. (2010), p. 5, (12) + // and p. 16, (16) ( -> ): + // + // - pred_noise_t -> epsilon_theta^(t)(x_t) + // - pred_original_sample -> f_theta^(t)(x_t) or x_0 + // - std_dev_t -> sigma_t (not the LMS sigma) + // - eta -> eta (set to 0 at the moment) + // - pred_sample_direction -> "direction pointing to + // x_t" + // - pred_prev_sample -> "x_t-1" + int timestep = + roundf(TIMESTEPS - + i * ((float)TIMESTEPS / steps)) - + 1; + // 1. get previous step value (=t-1) + int prev_timestep = timestep - TIMESTEPS / steps; + // The sigma here is chosen to cause the + // CompVisDenoiser to produce t = timestep + float sigma = compvis_sigmas[timestep]; + if (i == 0) { + // The function add_noise intializes x to + // Diffusers' latents * sigma (as in Diffusers' + // pipeline) or sample * sigma (Diffusers' + // scheduler), where this sigma = init_noise_sigma + // in Diffusers. For DDPM and DDIM however, + // init_noise_sigma = 1. But the k-diffusion + // model() also evaluates F_theta(c_in(sigma) x; + // ...) instead of the bare U-net F_theta, with + // c_in = 1 / sqrt(sigma^2 + 1), as defined in + // T. Karras et al., "Elucidating the Design Space + // of Diffusion-Based Generative Models", + // arXiv:2206.00364 [cs.CV], p. 3, Table 1. Hence + // the first call has to be prescaled as x <- x / + // (c_in * sigma) with the k-diffusion pipeline + // and CompVisDenoiser. + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1) / + sigma; + } + } else { + // For the subsequent steps after the first one, + // at this point x = latents or x = sample, and + // needs to be prescaled with x <- sample / c_in + // to compensate for model() applying the scale + // c_in before the U-net F_theta + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1); + } + } + // Note (also noise_pred in Diffuser's pipeline) + // model_output = model() is the D(x, sigma) as + // defined in Karras et al. (2022), p. 3, Table 1 and + // p. 8 (7), compare also p. 38 (226) therein. + struct ggml_tensor* model_output = + model(x, sigma, i + 1); + // Here model_output is still the k-diffusion denoiser + // output, not the U-net output F_theta(c_in(sigma) x; + // ...) in Karras et al. (2022), whereas Diffusers' + // model_output is F_theta(...). Recover the actual + // model_output, which is also referred to as the + // "Karras ODE derivative" d or d_cur in several + // samplers above. + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_model_output[j] = + (vec_x[j] - vec_model_output[j]) * + (1 / sigma); + } + } + // 2. compute alphas, betas + float alpha_prod_t = alphas_cumprod[timestep]; + // Note final_alpha_cumprod = alphas_cumprod[0] due to + // trailing timestep spacing + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + float beta_prod_t = 1 - alpha_prod_t; + // 3. compute predicted original sample from predicted + // noise also called "predicted x_0" of formula (12) + // from https://arxiv.org/pdf/2010.02502.pdf + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + // Note the substitution of latents or sample = x + // * c_in = x / sqrt(sigma^2 + 1) + for (int j = 0; j < ggml_nelements(x); j++) { + vec_pred_original_sample[j] = + (vec_x[j] / std::sqrt(sigma * sigma + 1) - + std::sqrt(beta_prod_t) * + vec_model_output[j]) * + (1 / std::sqrt(alpha_prod_t)); + } + } + // Assuming the "epsilon" prediction type, where below + // pred_epsilon = model_output is inserted, and is not + // defined/copied explicitly. + // + // 5. compute variance: "sigma_t(eta)" -> see formula + // (16) + // + // sigma_t = sqrt((1 - alpha_t-1)/(1 - alpha_t)) * + // sqrt(1 - alpha_t/alpha_t-1) + float beta_prod_t_prev = 1 - alpha_prod_t_prev; + float variance = (beta_prod_t_prev / beta_prod_t) * + (1 - alpha_prod_t / alpha_prod_t_prev); + float std_dev_t = eta * std::sqrt(variance); + // 6. compute "direction pointing to x_t" of formula + // (12) from https://arxiv.org/pdf/2010.02502.pdf + // 7. compute x_t without "random noise" of formula + // (12) from https://arxiv.org/pdf/2010.02502.pdf + { + float* vec_model_output = (float*)model_output->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Two step inner loop without an explicit + // tensor + float pred_sample_direction = + std::sqrt(1 - alpha_prod_t_prev - + std::pow(std_dev_t, 2)) * + vec_model_output[j]; + vec_x[j] = std::sqrt(alpha_prod_t_prev) * + vec_pred_original_sample[j] + + pred_sample_direction; + } + } + if (eta > 0) { + ggml_tensor_set_f32_randn(variance_noise, rng); + float* vec_variance_noise = + (float*)variance_noise->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] += std_dev_t * vec_variance_noise[j]; + } + } + // See the note above: x = latents or sample here, and + // is not scaled by the c_in. For the final output + // this is correct, but for subsequent iterations, x + // needs to be prescaled again, since k-diffusion's + // model() differes from the bare U-net F_theta by the + // factor c_in. + } + } break; + case TCD: // Strategic Stochastic Sampling (Algorithm 4) in + // Trajectory Consistency Distillation + { + // See J. Zheng et al., "Trajectory Consistency + // Distillation: Improved Latent Consistency Distillation + // by Semi-Linear Consistency Function with Trajectory + // Mapping", arXiv:2402.19159 [cs.CV] + float beta_start = 0.00085f; + float beta_end = 0.0120f; + std::vector alphas_cumprod; + std::vector compvis_sigmas; + + alphas_cumprod.reserve(TIMESTEPS); + compvis_sigmas.reserve(TIMESTEPS); + for (int i = 0; i < TIMESTEPS; i++) { + alphas_cumprod[i] = + (i == 0 ? 1.0f : alphas_cumprod[i - 1]) * + (1.0f - + std::pow(sqrtf(beta_start) + + (sqrtf(beta_end) - sqrtf(beta_start)) * + ((float)i / (TIMESTEPS - 1)), + 2)); + compvis_sigmas[i] = + std::sqrt((1 - alphas_cumprod[i]) / + alphas_cumprod[i]); + } + int original_steps = 50; + + struct ggml_tensor* pred_original_sample = + ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* noise = + ggml_dup_tensor(work_ctx, x); + + for (int i = 0; i < steps; i++) { + // Analytic form for TCD timesteps + int timestep = TIMESTEPS - 1 - + (TIMESTEPS / original_steps) * + (int)floor(i * ((float)original_steps / steps)); + // 1. get previous step value + int prev_timestep = i >= steps - 1 ? 0 : TIMESTEPS - 1 - (TIMESTEPS / original_steps) * (int)floor((i + 1) * ((float)original_steps / steps)); + // Here timestep_s is tau_n' in Algorithm 4. The _s + // notation appears to be that from C. Lu, + // "DPM-Solver: A Fast ODE Solver for Diffusion + // Probabilistic Model Sampling in Around 10 Steps", + // arXiv:2206.00927 [cs.LG], but this notation is not + // continued in Algorithm 4, where _n' is used. + int timestep_s = + (int)floor((1 - eta) * prev_timestep); + // Begin k-diffusion specific workaround for + // evaluating F_theta(x; ...) from D(x, sigma), same + // as in DDIM (and see there for detailed comments) + float sigma = compvis_sigmas[timestep]; + if (i == 0) { + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1) / + sigma; + } + } else { + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_x[j] *= std::sqrt(sigma * sigma + 1); + } + } + struct ggml_tensor* model_output = + model(x, sigma, i + 1); + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_model_output[j] = + (vec_x[j] - vec_model_output[j]) * + (1 / sigma); + } + } + // 2. compute alphas, betas + // + // When comparing TCD with DDPM/DDIM note that Zheng + // et al. (2024) follows the DPM-Solver notation for + // alpha. One can find the following comment in the + // original DPM-Solver code + // (https://github.com/LuChengTHU/dpm-solver/): + // "**Important**: Please pay special attention for + // the args for `alphas_cumprod`: The `alphas_cumprod` + // is the \hat{alpha_n} arrays in the notations of + // DDPM. [...] Therefore, the notation \hat{alpha_n} + // is different from the notation alpha_t in + // DPM-Solver. In fact, we have alpha_{t_n} = + // \sqrt{\hat{alpha_n}}, [...]" + float alpha_prod_t = alphas_cumprod[timestep]; + float beta_prod_t = 1 - alpha_prod_t; + // Note final_alpha_cumprod = alphas_cumprod[0] since + // TCD is always "trailing" + float alpha_prod_t_prev = prev_timestep >= 0 ? alphas_cumprod[prev_timestep] : alphas_cumprod[0]; + // The subscript _s are the only portion in this + // section (2) unique to TCD + float alpha_prod_s = alphas_cumprod[timestep_s]; + float beta_prod_s = 1 - alpha_prod_s; + // 3. Compute the predicted noised sample x_s based on + // the model parameterization + // + // This section is also exactly the same as DDIM + { + float* vec_x = (float*)x->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + for (int j = 0; j < ggml_nelements(x); j++) { + vec_pred_original_sample[j] = + (vec_x[j] / std::sqrt(sigma * sigma + 1) - + std::sqrt(beta_prod_t) * + vec_model_output[j]) * + (1 / std::sqrt(alpha_prod_t)); + } + } + // This consistency function step can be difficult to + // decipher from Algorithm 4, as it is simply stated + // using a consistency function. This step is the + // modified DDIM, i.e. p. 8 (32) in Zheng et + // al. (2024), with eta set to 0 (see the paragraph + // immediately thereafter that states this somewhat + // obliquely). + { + float* vec_pred_original_sample = + (float*)pred_original_sample->data; + float* vec_model_output = + (float*)model_output->data; + float* vec_x = (float*)x->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Substituting x = pred_noised_sample and + // pred_epsilon = model_output + vec_x[j] = + std::sqrt(alpha_prod_s) * + vec_pred_original_sample[j] + + std::sqrt(beta_prod_s) * + vec_model_output[j]; + } + } + // 4. Sample and inject noise z ~ N(0, I) for + // MultiStep Inference Noise is not used on the final + // timestep of the timestep schedule. This also means + // that noise is not used for one-step sampling. Eta + // (referred to as "gamma" in the paper) was + // introduced to control the stochasticity in every + // step. When eta = 0, it represents deterministic + // sampling, whereas eta = 1 indicates full stochastic + // sampling. + if (eta > 0 && i != steps - 1) { + // In this case, x is still pred_noised_sample, + // continue in-place + ggml_tensor_set_f32_randn(noise, rng); + float* vec_x = (float*)x->data; + float* vec_noise = (float*)noise->data; + for (int j = 0; j < ggml_nelements(x); j++) { + // Corresponding to (35) in Zheng et + // al. (2024), substituting x = + // pred_noised_sample + vec_x[j] = + std::sqrt(alpha_prod_t_prev / + alpha_prod_s) * + vec_x[j] + + std::sqrt(1 - alpha_prod_t_prev / + alpha_prod_s) * + vec_noise[j]; + } + } + } + } break; default: LOG_ERROR("Attempting to sample with nonexisting sample method %i", method); diff --git a/diffusion_model.hpp b/diffusion_model.hpp index 2530f7149..5c349439d 100644 --- a/diffusion_model.hpp +++ b/diffusion_model.hpp @@ -13,11 +13,13 @@ struct DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) = 0; + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) = 0; virtual void alloc_params_buffer() = 0; virtual void free_params_buffer() = 0; virtual void free_compute_buffer() = 0; @@ -30,9 +32,10 @@ struct UNetModel : public DiffusionModel { UNetModelRunner unet; UNetModel(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD1) - : unet(backend, wtype, version) { + std::map& tensor_types, + SDVersion version = VERSION_SD1, + bool flash_attn = false) + : unet(backend, tensor_types, "model.diffusion_model", version, flash_attn) { } void alloc_params_buffer() { @@ -66,11 +69,14 @@ struct UNetModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + (void)skip_layers; // SLG doesn't work with UNet models return unet.compute(n_threads, x, timesteps, context, c_concat, y, num_video_frames, controls, control_strength, output, output_ctx); } }; @@ -79,9 +85,8 @@ struct MMDiTModel : public DiffusionModel { MMDiTRunner mmdit; MMDiTModel(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD3_2B) - : mmdit(backend, wtype, version) { + std::map& tensor_types) + : mmdit(backend, tensor_types, "model.diffusion_model") { } void alloc_params_buffer() { @@ -115,12 +120,14 @@ struct MMDiTModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { - return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx); + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return mmdit.compute(n_threads, x, timesteps, context, y, output, output_ctx, skip_layers); } }; @@ -128,9 +135,11 @@ struct FluxModel : public DiffusionModel { Flux::FluxRunner flux; FluxModel(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) - : flux(backend, wtype, version) { + std::map& tensor_types, + SDVersion version = VERSION_FLUX, + bool flash_attn = false, + bool use_mask = false) + : flux(backend, tensor_types, "model.diffusion_model", version, flash_attn, use_mask) { } void alloc_params_buffer() { @@ -164,13 +173,15 @@ struct FluxModel : public DiffusionModel { struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, + std::vector ref_latents = {}, int num_video_frames = -1, std::vector controls = {}, float control_strength = 0.f, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { - return flux.compute(n_threads, x, timesteps, context, y, guidance, output, output_ctx); + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { + return flux.compute(n_threads, x, timesteps, context, c_concat, y, guidance, ref_latents, output, output_ctx, skip_layers); } }; -#endif \ No newline at end of file +#endif diff --git a/docs/chroma.md b/docs/chroma.md new file mode 100644 index 000000000..d013a43c8 --- /dev/null +++ b/docs/chroma.md @@ -0,0 +1,33 @@ +# How to Use + +You can run Chroma using stable-diffusion.cpp with a GPU that has 6GB or even 4GB of VRAM, without needing to offload to RAM. + +## Download weights + +- Download Chroma + - If you don't want to do the conversion yourself, download the preconverted gguf model from [silveroxides/Chroma-GGUF](https://huggingface.co/silveroxides/Chroma-GGUF) + - Otherwise, download chroma's safetensors from [lodestones/Chroma](https://huggingface.co/lodestones/Chroma) +- Download vae from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors +- Download t5xxl from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors + +## Convert Chroma weights + +You can download the preconverted gguf weights from [silveroxides/Chroma-GGUF](https://huggingface.co/silveroxides/Chroma-GGUF), this way you don't have to do the conversion yourself. + +``` +.\bin\Release\sd.exe -M convert -m ..\..\ComfyUI\models\unet\chroma-unlocked-v40.safetensors -o ..\models\chroma-unlocked-v40-q8_0.gguf -v --type q8_0 +``` + +## Run + +### Example +For example: + +``` + .\bin\Release\sd.exe -diffusion-model ..\models\chroma-unlocked-v40-q8_0.gguf --vae ..\models\ae.sft --t5xxl ..\models\t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'chroma.cpp'" --cfg-scale 4.0 --sampling-method euler -v --chroma-disable-dit-mask +``` + +![](../assets/flux/chroma_v40.png) + + + diff --git a/docs/kontext.md b/docs/kontext.md new file mode 100644 index 000000000..698735039 --- /dev/null +++ b/docs/kontext.md @@ -0,0 +1,39 @@ +# How to Use + +You can run Kontext using stable-diffusion.cpp with a GPU that has 6GB or even 4GB of VRAM, without needing to offload to RAM. + +## Download weights + +- Download Kontext + - If you don't want to do the conversion yourself, download the preconverted gguf model from [FLUX.1-Kontext-dev-GGUF](https://huggingface.co/QuantStack/FLUX.1-Kontext-dev-GGUF) + - Otherwise, download FLUX.1-Kontext-dev from https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev/blob/main/flux1-kontext-dev.safetensors +- Download vae from https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/ae.safetensors +- Download clip_l from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/clip_l.safetensors +- Download t5xxl from https://huggingface.co/comfyanonymous/flux_text_encoders/blob/main/t5xxl_fp16.safetensors + +## Convert Kontext weights + +You can download the preconverted gguf weights from [FLUX.1-Kontext-dev-GGUF](https://huggingface.co/QuantStack/FLUX.1-Kontext-dev-GGUF), this way you don't have to do the conversion yourself. + +``` +.\bin\Release\sd.exe -M convert -m ..\..\ComfyUI\models\unet\flux1-kontext-dev.safetensors -o ..\models\flux1-kontext-dev-q8_0.gguf -v --type q8_0 +``` + +## Run + +- `--cfg-scale` is recommended to be set to 1. + +### Example +For example: + +``` + .\bin\Release\sd.exe -r .\flux1-dev-q8_0.png --diffusion-model ..\models\flux1-kontext-dev-q8_0.gguf --vae ..\models\ae.sft --clip_l ..\models\clip_l.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -p "change 'flux.cpp' to 'kontext.cpp'" --cfg-scale 1.0 --sampling-method euler -v +``` + + +| ref_image | prompt | output | +| ---- | ---- |---- | +| ![](../assets/flux/flux1-dev-q8_0.png) | change 'flux.cpp' to 'kontext.cpp' |![](../assets/flux/kontext1_dev_output.png) | + + + diff --git a/docs/photo_maker.md b/docs/photo_maker.md index b69ad97d9..8305a33bd 100644 --- a/docs/photo_maker.md +++ b/docs/photo_maker.md @@ -29,4 +29,26 @@ Example: ```bash bin/sd -m ../models/sdxlUnstableDiffusers_v11.safetensors --vae ../models/sdxl_vae.safetensors --stacked-id-embd-dir ../models/photomaker-v1.safetensors --input-id-images-dir ../assets/photomaker_examples/scarletthead_woman -p "a girl img, retro futurism, retro game art style but extremely beautiful, intricate details, masterpiece, best quality, space-themed, cosmic, celestial, stars, galaxies, nebulas, planets, science fiction, highly detailed" -n "realistic, photo-realistic, worst quality, greyscale, bad anatomy, bad hands, error, text" --cfg-scale 5.0 --sampling-method euler -H 1024 -W 1024 --style-ratio 10 --vae-on-cpu -o output.png -``` \ No newline at end of file +``` + +## PhotoMaker Version 2 + +[PhotoMaker Version 2 (PMV2)](https://github.com/TencentARC/PhotoMaker/blob/main/README_pmv2.md) has some key improvements. Unfortunately it has a very heavy dependency which makes running it a bit involved in ```SD.cpp```. + +Running PMV2 is now a two-step process: + +- Run a python script ```face_detect.py``` to obtain **id_embeds** for the given input images +``` +python face_detect.py input_image_dir +``` +An ```id_embeds.safetensors``` file will be generated in ```input_images_dir``` + +**Note: this step is only needed to run once; the same ```id_embeds``` can be reused** + +- Run the same command as in version 1 but replacing ```photomaker-v1.safetensors``` with ```photomaker-v2.safetensors```. + + You can download ```photomaker-v2.safetensors``` from [here](https://huggingface.co/bssrdf/PhotoMakerV2) + +- All the command line parameters from Version 1 remain the same for Version 2 + + diff --git a/esrgan.hpp b/esrgan.hpp index 33fcf09a4..5cbb4ad8f 100644 --- a/esrgan.hpp +++ b/esrgan.hpp @@ -130,8 +130,8 @@ class RRDBNet : public GGMLBlock { body_feat = conv_body->forward(ctx, body_feat); feat = ggml_add(ctx, feat, body_feat); // upsample - feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2))); - feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2))); + feat = lrelu(ctx, conv_up1->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); + feat = lrelu(ctx, conv_up2->forward(ctx, ggml_upscale(ctx, feat, 2, GGML_SCALE_MODE_NEAREST))); auto out = conv_last->forward(ctx, lrelu(ctx, conv_hr->forward(ctx, feat))); return out; } @@ -142,10 +142,9 @@ struct ESRGAN : public GGMLRunner { int scale = 4; int tile_size = 128; // avoid cuda OOM for 4gb VRAM - ESRGAN(ggml_backend_t backend, - ggml_type wtype) - : GGMLRunner(backend, wtype) { - rrdb_net.init(params_ctx, wtype); + ESRGAN(ggml_backend_t backend, std::map& tensor_types) + : GGMLRunner(backend) { + rrdb_net.init(params_ctx, tensor_types, ""); } std::string get_desc() { diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 81053f9e2..2dcd1d53a 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,3 +1,4 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}) -add_subdirectory(cli) \ No newline at end of file +add_subdirectory(cli) +add_subdirectory(server) \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index f1bdc698b..b3ae569e6 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -1,13 +1,15 @@ #include #include #include +#include #include +#include #include +#include #include #include // #include "preprocessing.hpp" -#include "flux.hpp" #include "stable-diffusion.h" #define STB_IMAGE_IMPLEMENTATION @@ -22,53 +24,26 @@ #define STB_IMAGE_RESIZE_STATIC #include "stb_image_resize.h" -const char* rng_type_to_str[] = { - "std_default", - "cuda", -}; - -// Names of the sampler method, same order as enum sample_method in stable-diffusion.h -const char* sample_method_str[] = { - "euler_a", - "euler", - "heun", - "dpm2", - "dpm++2s_a", - "dpm++2m", - "dpm++2mv2", - "ipndm", - "ipndm_v", - "lcm", -}; - -// Names of the sigma schedule overrides, same order as sample_schedule in stable-diffusion.h -const char* schedule_str[] = { - "default", - "discrete", - "karras", - "exponential", - "ays", - "gits", -}; +#define SAFE_STR(s) ((s) ? (s) : "") +#define BOOL_STR(b) ((b) ? "true" : "false") const char* modes_str[] = { - "txt2img", - "img2img", - "img2vid", + "img_gen", + "vid_gen", "convert", }; +#define SD_ALL_MODES_STR "img_gen, vid_gen, convert" enum SDMode { - TXT2IMG, - IMG2IMG, - IMG2VID, + IMG_GEN, + VID_GEN, CONVERT, MODE_COUNT }; struct SDParams { int n_threads = -1; - SDMode mode = TXT2IMG; + SDMode mode = IMG_GEN; std::string model_path; std::string clip_l_path; std::string clip_g_path; @@ -77,26 +52,31 @@ struct SDParams { std::string vae_path; std::string taesd_path; std::string esrgan_path; - std::string controlnet_path; - std::string embeddings_path; - std::string stacked_id_embeddings_path; + std::string control_net_path; + std::string embedding_dir; + std::string stacked_id_embed_dir; std::string input_id_images_path; sd_type_t wtype = SD_TYPE_COUNT; + std::string tensor_type_rules; std::string lora_model_dir; std::string output_path = "output.png"; std::string input_path; + std::string mask_path; std::string control_image_path; + std::vector ref_image_paths; std::string prompt; std::string negative_prompt; - float min_cfg = 1.0f; - float cfg_scale = 7.0f; - float guidance = 3.5f; - float style_ratio = 20.f; - int clip_skip = -1; // <= 0 represents unspecified - int width = 512; - int height = 512; - int batch_count = 1; + float min_cfg = 1.0f; + float cfg_scale = 7.0f; + float img_cfg_scale = INFINITY; + float guidance = 3.5f; + float eta = 0.f; + float style_ratio = 20.f; + int clip_skip = -1; // <= 0 represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; int video_frames = 6; int motion_bucket_id = 127; @@ -116,9 +96,19 @@ struct SDParams { bool normalize_input = false; bool clip_on_cpu = false; bool vae_on_cpu = false; + bool diffusion_flash_attn = false; bool canny_preprocess = false; bool color = false; int upscale_repeats = 1; + + std::vector skip_layers = {7, 8, 9}; + float slg_scale = 0.f; + float skip_layer_start = 0.01f; + float skip_layer_end = 0.2f; + + bool chroma_use_dit_mask = true; + bool chroma_use_t5_mask = false; + int chroma_t5_mask_pad = 1; }; void print_params(SDParams params) { @@ -134,36 +124,48 @@ void print_params(SDParams params) { printf(" vae_path: %s\n", params.vae_path.c_str()); printf(" taesd_path: %s\n", params.taesd_path.c_str()); printf(" esrgan_path: %s\n", params.esrgan_path.c_str()); - printf(" controlnet_path: %s\n", params.controlnet_path.c_str()); - printf(" embeddings_path: %s\n", params.embeddings_path.c_str()); - printf(" stacked_id_embeddings_path: %s\n", params.stacked_id_embeddings_path.c_str()); + printf(" control_net_path: %s\n", params.control_net_path.c_str()); + printf(" embedding_dir: %s\n", params.embedding_dir.c_str()); + printf(" stacked_id_embed_dir: %s\n", params.stacked_id_embed_dir.c_str()); printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str()); printf(" style ratio: %.2f\n", params.style_ratio); printf(" normalize input image : %s\n", params.normalize_input ? "true" : "false"); printf(" output_path: %s\n", params.output_path.c_str()); printf(" init_img: %s\n", params.input_path.c_str()); + printf(" mask_img: %s\n", params.mask_path.c_str()); printf(" control_image: %s\n", params.control_image_path.c_str()); + printf(" ref_images_paths:\n"); + for (auto& path : params.ref_image_paths) { + printf(" %s\n", path.c_str()); + }; printf(" clip on cpu: %s\n", params.clip_on_cpu ? "true" : "false"); printf(" controlnet cpu: %s\n", params.control_net_cpu ? "true" : "false"); printf(" vae decoder on cpu:%s\n", params.vae_on_cpu ? "true" : "false"); + printf(" diffusion flash attention:%s\n", params.diffusion_flash_attn ? "true" : "false"); printf(" strength(control): %.2f\n", params.control_strength); printf(" prompt: %s\n", params.prompt.c_str()); printf(" negative_prompt: %s\n", params.negative_prompt.c_str()); printf(" min_cfg: %.2f\n", params.min_cfg); printf(" cfg_scale: %.2f\n", params.cfg_scale); + printf(" img_cfg_scale: %.2f\n", params.img_cfg_scale); + printf(" slg_scale: %.2f\n", params.slg_scale); printf(" guidance: %.2f\n", params.guidance); + printf(" eta: %.2f\n", params.eta); printf(" clip_skip: %d\n", params.clip_skip); printf(" width: %d\n", params.width); printf(" height: %d\n", params.height); - printf(" sample_method: %s\n", sample_method_str[params.sample_method]); - printf(" schedule: %s\n", schedule_str[params.schedule]); + printf(" sample_method: %s\n", sd_sample_method_name(params.sample_method)); + printf(" schedule: %s\n", sd_schedule_name(params.schedule)); printf(" sample_steps: %d\n", params.sample_steps); printf(" strength(img2img): %.2f\n", params.strength); - printf(" rng: %s\n", rng_type_to_str[params.rng_type]); + printf(" rng: %s\n", sd_rng_type_name(params.rng_type)); printf(" seed: %ld\n", params.seed); printf(" batch_count: %d\n", params.batch_count); printf(" vae_tiling: %s\n", params.vae_tiling ? "true" : "false"); printf(" upscale_repeats: %d\n", params.upscale_repeats); + printf(" chroma_use_dit_mask: %s\n", params.chroma_use_dit_mask ? "true" : "false"); + printf(" chroma_use_t5_mask: %s\n", params.chroma_use_t5_mask ? "true" : "false"); + printf(" chroma_t5_mask_pad: %d\n", params.chroma_t5_mask_pad); } void print_usage(int argc, const char* argv[]) { @@ -171,14 +173,14 @@ void print_usage(int argc, const char* argv[]) { printf("\n"); printf("arguments:\n"); printf(" -h, --help show this help message and exit\n"); - printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n"); + printf(" -M, --mode [MODE] run mode, one of: [img_gen, convert], default: img_gen\n"); printf(" -t, --threads N number of threads to use during computation (default: -1)\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" -m, --model [MODEL] path to full model\n"); printf(" --diffusion-model path to the standalone diffusion model\n"); printf(" --clip_l path to the clip-l text encoder\n"); - printf(" --clip_g path to the clip-l text encoder\n"); - printf(" --t5xxl path to the the t5xxl text encoder\n"); + printf(" --clip_g path to the clip-g text encoder\n"); + printf(" --t5xxl path to the t5xxl text encoder\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); @@ -188,22 +190,34 @@ void print_usage(int argc, const char* argv[]) { printf(" --normalize-input normalize PHOTOMAKER input id images\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); - printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n"); + printf(" --type [TYPE] weight type (examples: f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_K, q3_K, q4_K)\n"); printf(" If not specified, the default is the type of the weight file\n"); + printf(" --tensor-type-rules [EXPRESSION] weight type per tensor pattern (example: \"^vae\\.=f16,model\\.=q8_0\")\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); + printf(" --mask [MASK] path to the mask image, required by img2img with mask\n"); printf(" --control-image [IMAGE] path to image condition, control net\n"); + printf(" -r, --ref-image [PATH] reference image for Flux Kontext models (can be used multiple times) \n"); printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); printf(" -p, --prompt [PROMPT] the prompt to render\n"); printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); + printf(" --img-cfg-scale SCALE image guidance scale for inpaint or instruct-pix2pix models: (default: same as --cfg-scale)\n"); + printf(" --guidance SCALE distilled guidance scale for models with guidance input (default: 3.5)\n"); + printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n"); + printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n"); + printf(" --eta SCALE eta in DDIM, only for DDIM and TCD: (default: 0)\n"); + printf(" --skip-layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n"); + printf(" --skip-layer-start START SLG enabling point: (default: 0.01)\n"); + printf(" --skip-layer-end END SLG disabling point: (default: 0.2)\n"); + printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n"); printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); - printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n"); + printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20)\n"); printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); printf(" 1.0 corresponds to full destruction of information in init image\n"); printf(" -H, --height H image height, in pixel space (default: 512)\n"); printf(" -W, --width W image width, in pixel space (default: 512)\n"); - printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n"); + printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm, ddim_trailing, tcd}\n"); printf(" sampling method (default: \"euler_a\")\n"); printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); @@ -215,341 +229,356 @@ void print_usage(int argc, const char* argv[]) { printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); printf(" --clip-on-cpu keep clip in cpu (for low vram)\n"); + printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n"); + printf(" Might lower quality, since it implies converting k and v to f16.\n"); + printf(" This might crash if it is not supported by the backend.\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); - printf(" --color Colors the logging tags according to level\n"); + printf(" --color colors the logging tags according to level\n"); + printf(" --chroma-disable-dit-mask disable dit mask for chroma\n"); + printf(" --chroma-enable-t5-mask enable t5 mask for chroma\n"); + printf(" --chroma-t5-mask-pad PAD_SIZE t5 mask pad size of chroma\n"); printf(" -v, --verbose print extra info\n"); } -void parse_args(int argc, const char** argv, SDParams& params) { +struct StringOption { + std::string short_name; + std::string long_name; + std::string desc; + std::string* target; +}; + +struct IntOption { + std::string short_name; + std::string long_name; + std::string desc; + int* target; +}; + +struct FloatOption { + std::string short_name; + std::string long_name; + std::string desc; + float* target; +}; + +struct BoolOption { + std::string short_name; + std::string long_name; + std::string desc; + bool keep_true; + bool* target; +}; + +struct ManualOption { + std::string short_name; + std::string long_name; + std::string desc; + std::function cb; +}; + +struct ArgOptions { + std::vector string_options; + std::vector int_options; + std::vector float_options; + std::vector bool_options; + std::vector manual_options; +}; + +bool parse_options(int argc, const char** argv, ArgOptions& options) { bool invalid_arg = false; std::string arg; for (int i = 1; i < argc; i++) { arg = argv[i]; - if (arg == "-t" || arg == "--threads") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.n_threads = std::stoi(argv[i]); - } else if (arg == "-M" || arg == "--mode") { - if (++i >= argc) { - invalid_arg = true; - break; - } - const char* mode_selected = argv[i]; - int mode_found = -1; - for (int d = 0; d < MODE_COUNT; d++) { - if (!strcmp(mode_selected, modes_str[d])) { - mode_found = d; + for (auto& option : options.string_options) { + if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { + if (++i >= argc) { + invalid_arg = true; + break; } + *option.target = std::string(argv[i]); } - if (mode_found == -1) { - fprintf(stderr, - "error: invalid mode %s, must be one of [txt2img, img2img, img2vid, convert]\n", - mode_selected); - exit(1); - } - params.mode = (SDMode)mode_found; - } else if (arg == "-m" || arg == "--model") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.model_path = argv[i]; - } else if (arg == "--clip_l") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.clip_l_path = argv[i]; - } else if (arg == "--clip_g") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.clip_g_path = argv[i]; - } else if (arg == "--t5xxl") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.t5xxl_path = argv[i]; - } else if (arg == "--diffusion-model") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.diffusion_model_path = argv[i]; - } else if (arg == "--vae") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.vae_path = argv[i]; - } else if (arg == "--taesd") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.taesd_path = argv[i]; - } else if (arg == "--control-net") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.controlnet_path = argv[i]; - } else if (arg == "--upscale-model") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.esrgan_path = argv[i]; - } else if (arg == "--embd-dir") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.embeddings_path = argv[i]; - } else if (arg == "--stacked-id-embd-dir") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.stacked_id_embeddings_path = argv[i]; - } else if (arg == "--input-id-images-dir") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.input_id_images_path = argv[i]; - } else if (arg == "--type") { - if (++i >= argc) { - invalid_arg = true; - break; - } - std::string type = argv[i]; - if (type == "f32") { - params.wtype = SD_TYPE_F32; - } else if (type == "f16") { - params.wtype = SD_TYPE_F16; - } else if (type == "q4_0") { - params.wtype = SD_TYPE_Q4_0; - } else if (type == "q4_1") { - params.wtype = SD_TYPE_Q4_1; - } else if (type == "q5_0") { - params.wtype = SD_TYPE_Q5_0; - } else if (type == "q5_1") { - params.wtype = SD_TYPE_Q5_1; - } else if (type == "q8_0") { - params.wtype = SD_TYPE_Q8_0; - } else if (type == "q2_k") { - params.wtype = SD_TYPE_Q2_K; - } else if (type == "q3_k") { - params.wtype = SD_TYPE_Q3_K; - } else if (type == "q4_k") { - params.wtype = SD_TYPE_Q4_K; - } else { - fprintf(stderr, "error: invalid weight format %s, must be one of [f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k]\n", - type.c_str()); - exit(1); - } - } else if (arg == "--lora-model-dir") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.lora_model_dir = argv[i]; - } else if (arg == "-i" || arg == "--init-img") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.input_path = argv[i]; - } else if (arg == "--control-image") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.control_image_path = argv[i]; - } else if (arg == "-o" || arg == "--output") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.output_path = argv[i]; - } else if (arg == "-p" || arg == "--prompt") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.prompt = argv[i]; - } else if (arg == "--upscale-repeats") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.upscale_repeats = std::stoi(argv[i]); - if (params.upscale_repeats < 1) { - fprintf(stderr, "error: upscale multiplier must be at least 1\n"); - exit(1); - } - } else if (arg == "-n" || arg == "--negative-prompt") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.negative_prompt = argv[i]; - } else if (arg == "--cfg-scale") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.cfg_scale = std::stof(argv[i]); - } else if (arg == "--guidance") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.guidance = std::stof(argv[i]); - } else if (arg == "--strength") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.strength = std::stof(argv[i]); - } else if (arg == "--style-ratio") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.style_ratio = std::stof(argv[i]); - } else if (arg == "--control-strength") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.control_strength = std::stof(argv[i]); - } else if (arg == "-H" || arg == "--height") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.height = std::stoi(argv[i]); - } else if (arg == "-W" || arg == "--width") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.width = std::stoi(argv[i]); - } else if (arg == "--steps") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.sample_steps = std::stoi(argv[i]); - } else if (arg == "--clip-skip") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.clip_skip = std::stoi(argv[i]); - } else if (arg == "--vae-tiling") { - params.vae_tiling = true; - } else if (arg == "--control-net-cpu") { - params.control_net_cpu = true; - } else if (arg == "--normalize-input") { - params.normalize_input = true; - } else if (arg == "--clip-on-cpu") { - params.clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs - } else if (arg == "--vae-on-cpu") { - params.vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs - } else if (arg == "--canny") { - params.canny_preprocess = true; - } else if (arg == "-b" || arg == "--batch-count") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.batch_count = std::stoi(argv[i]); - } else if (arg == "--rng") { - if (++i >= argc) { - invalid_arg = true; - break; - } - std::string rng_type_str = argv[i]; - if (rng_type_str == "std_default") { - params.rng_type = STD_DEFAULT_RNG; - } else if (rng_type_str == "cuda") { - params.rng_type = CUDA_RNG; - } else { - invalid_arg = true; - break; - } - } else if (arg == "--schedule") { - if (++i >= argc) { - invalid_arg = true; - break; - } - const char* schedule_selected = argv[i]; - int schedule_found = -1; - for (int d = 0; d < N_SCHEDULES; d++) { - if (!strcmp(schedule_selected, schedule_str[d])) { - schedule_found = d; + } + if (invalid_arg) { + break; + } + + for (auto& option : options.int_options) { + if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { + if (++i >= argc) { + invalid_arg = true; + break; } + *option.target = std::stoi(argv[i]); } - if (schedule_found == -1) { - invalid_arg = true; - break; - } - params.schedule = (schedule_t)schedule_found; - } else if (arg == "-s" || arg == "--seed") { - if (++i >= argc) { - invalid_arg = true; - break; - } - params.seed = std::stoll(argv[i]); - } else if (arg == "--sampling-method") { - if (++i >= argc) { - invalid_arg = true; - break; + } + if (invalid_arg) { + break; + } + + for (auto& option : options.float_options) { + if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { + if (++i >= argc) { + invalid_arg = true; + break; + } + *option.target = std::stof(argv[i]); } - const char* sample_method_selected = argv[i]; - int sample_method_found = -1; - for (int m = 0; m < N_SAMPLE_METHODS; m++) { - if (!strcmp(sample_method_selected, sample_method_str[m])) { - sample_method_found = m; + } + if (invalid_arg) { + break; + } + + for (auto& option : options.bool_options) { + if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { + if (option.keep_true) { + *option.target = true; + } else { + *option.target = false; } } - if (sample_method_found == -1) { - invalid_arg = true; - break; + } + if (invalid_arg) { + break; + } + + for (auto& option : options.manual_options) { + if ((option.short_name.size() > 0 && arg == option.short_name) || (option.long_name.size() > 0 && arg == option.long_name)) { + int ret = option.cb(argc, argv, i); + if (ret < 0) { + invalid_arg = true; + break; + } + i += ret; } - params.sample_method = (sample_method_t)sample_method_found; - } else if (arg == "-h" || arg == "--help") { - print_usage(argc, argv); - exit(0); - } else if (arg == "-v" || arg == "--verbose") { - params.verbose = true; - } else if (arg == "--color") { - params.color = true; - } else { - fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); - print_usage(argc, argv); - exit(1); + } + if (invalid_arg) { + break; } } if (invalid_arg) { fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + return false; + } + return true; +} + +void parse_args(int argc, const char** argv, SDParams& params) { + ArgOptions options; + options.string_options = { + {"-m", "--model", "", ¶ms.model_path}, + {"", "--clip_l", "", ¶ms.clip_l_path}, + {"", "--clip_g", "", ¶ms.clip_g_path}, + {"", "--t5xxl", "", ¶ms.t5xxl_path}, + {"", "--diffusion-model", "", ¶ms.diffusion_model_path}, + {"", "--vae", "", ¶ms.vae_path}, + {"", "--taesd", "", ¶ms.taesd_path}, + {"", "--control-net", "", ¶ms.control_net_path}, + {"", "--embd-dir", "", ¶ms.embedding_dir}, + {"", "--stacked-id-embd-dir", "", ¶ms.stacked_id_embed_dir}, + {"", "--lora-model-dir", "", ¶ms.lora_model_dir}, + {"-i", "--init-img", "", ¶ms.input_path}, + {"", "--tensor-type-rules", "", ¶ms.tensor_type_rules}, + {"", "--input-id-images-dir", "", ¶ms.input_id_images_path}, + {"", "--mask", "", ¶ms.mask_path}, + {"", "--control-image", "", ¶ms.control_image_path}, + {"-o", "--output", "", ¶ms.output_path}, + {"-p", "--prompt", "", ¶ms.prompt}, + {"-n", "--negative-prompt", "", ¶ms.negative_prompt}, + + {"", "--upscale-model", "", ¶ms.esrgan_path}, + }; + + options.int_options = { + {"-t", "--threads", "", ¶ms.n_threads}, + {"", "--upscale-repeats", "", ¶ms.upscale_repeats}, + {"-H", "--height", "", ¶ms.height}, + {"-W", "--width", "", ¶ms.width}, + {"", "--steps", "", ¶ms.sample_steps}, + {"", "--clip-skip", "", ¶ms.clip_skip}, + {"-b", "--batch-count", "", ¶ms.batch_count}, + {"", "--chroma-t5-mask-pad", "", ¶ms.chroma_t5_mask_pad}, + }; + + options.float_options = { + {"", "--cfg-scale", "", ¶ms.cfg_scale}, + {"", "--img-cfg-scale", "", ¶ms.img_cfg_scale}, + {"", "--guidance", "", ¶ms.guidance}, + {"", "--eta", "", ¶ms.eta}, + {"", "--strength", "", ¶ms.strength}, + {"", "--style-ratio", "", ¶ms.style_ratio}, + {"", "--control-strength", "", ¶ms.control_strength}, + {"", "--slg-scale", "", ¶ms.slg_scale}, + {"", "--skip-layer-start", "", ¶ms.skip_layer_start}, + {"", "--skip-layer-end", "", ¶ms.skip_layer_end}, + + }; + + options.bool_options = { + {"", "--vae-tiling", "", true, ¶ms.vae_tiling}, + {"", "--control-net-cpu", "", true, ¶ms.control_net_cpu}, + {"", "--normalize-input", "", true, ¶ms.normalize_input}, + {"", "--clip-on-cpu", "", true, ¶ms.clip_on_cpu}, + {"", "--vae-on-cpu", "", true, ¶ms.vae_on_cpu}, + {"", "--diffusion-fa", "", true, ¶ms.diffusion_flash_attn}, + {"", "--canny", "", true, ¶ms.canny_preprocess}, + {"-v", "--verbos", "", true, ¶ms.verbose}, + {"", "--color", "", true, ¶ms.color}, + {"", "--chroma-disable-dit-mask", "", false, ¶ms.chroma_use_dit_mask}, + {"", "--chroma-enable-t5-mask", "", true, ¶ms.chroma_use_t5_mask}, + }; + + auto on_mode_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* mode = argv[index]; + if (mode != NULL) { + int mode_found = -1; + for (int i = 0; i < MODE_COUNT; i++) { + if (!strcmp(mode, modes_str[i])) { + mode_found = i; + } + } + if (mode_found == -1) { + fprintf(stderr, + "error: invalid mode %s, must be one of [%s]\n", + mode, SD_ALL_MODES_STR); + exit(1); + } + params.mode = (SDMode)mode_found; + } + return 1; + }; + + auto on_type_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + params.wtype = str_to_sd_type(arg); + if (params.wtype == SD_TYPE_COUNT) { + fprintf(stderr, "error: invalid weight format %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_rng_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + params.rng_type = str_to_rng_type(arg); + if (params.rng_type == RNG_TYPE_COUNT) { + fprintf(stderr, "error: invalid rng type %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_schedule_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + params.schedule = str_to_schedule(arg); + if (params.schedule == SCHEDULE_COUNT) { + fprintf(stderr, "error: invalid schedule %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_sample_method_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + const char* arg = argv[index]; + params.sample_method = str_to_sample_method(arg); + if (params.sample_method == SAMPLE_METHOD_COUNT) { + fprintf(stderr, "error: invalid sample method %s\n", + arg); + return -1; + } + return 1; + }; + + auto on_seed_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + params.seed = std::stoll(argv[index]); + return 1; + }; + + auto on_help_arg = [&](int argc, const char** argv, int index) { + print_usage(argc, argv); + exit(0); + return 0; + }; + + auto on_skip_layers_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + std::string layers_str = argv[index]; + if (layers_str[0] != '[' || layers_str[layers_str.size() - 1] != ']') { + return -1; + } + + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const std::invalid_argument& e) { + return -1; + } + } + params.skip_layers = layers; + return 1; + }; + + auto on_ref_image_arg = [&](int argc, const char** argv, int index) { + if (++index >= argc) { + return -1; + } + params.ref_image_paths.push_back(argv[index]); + return 1; + }; + + options.manual_options = { + {"-M", "--mode", "", on_mode_arg}, + {"", "--type", "", on_type_arg}, + {"", "--rng", "", on_rng_arg}, + {"-s", "--seed", "", on_seed_arg}, + {"", "--sampling-method", "", on_sample_method_arg}, + {"", "--schedule", "", on_schedule_arg}, + {"", "--skip-layers", "", on_skip_layers_arg}, + {"-r", "--ref-image", "", on_ref_image_arg}, + {"-h", "--help", "", on_help_arg}, + }; + + if (!parse_options(argc, argv, options)) { print_usage(argc, argv); exit(1); } + if (params.n_threads <= 0) { params.n_threads = get_num_physical_cores(); } - if (params.mode != CONVERT && params.mode != IMG2VID && params.prompt.length() == 0) { + if (params.mode != CONVERT && params.mode != VID_GEN && params.prompt.length() == 0) { fprintf(stderr, "error: the following arguments are required: prompt\n"); print_usage(argc, argv); exit(1); @@ -561,12 +590,6 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } - if ((params.mode == IMG2IMG || params.mode == IMG2VID) && params.input_path.length() == 0) { - fprintf(stderr, "error: when using the img2img mode, the following arguments are required: init-img\n"); - print_usage(argc, argv); - exit(1); - } - if (params.output_path.length() == 0) { fprintf(stderr, "error: the following arguments are required: output_path\n"); print_usage(argc, argv); @@ -593,6 +616,15 @@ void parse_args(int argc, const char** argv, SDParams& params) { exit(1); } + if (params.mode != CONVERT && params.tensor_type_rules.size() > 0) { + fprintf(stderr, "warning: --tensor-type-rules is currently supported only for conversion\n"); + } + + if (params.upscale_repeats < 1) { + fprintf(stderr, "error: upscale multiplier must be at least 1\n"); + exit(1); + } + if (params.seed < 0) { srand((int)time(NULL)); params.seed = rand(); @@ -603,6 +635,10 @@ void parse_args(int argc, const char** argv, SDParams& params) { params.output_path = "output.gguf"; } } + + if (!isfinite(params.img_cfg_scale)) { + params.img_cfg_scale = params.cfg_scale; + } } static std::string sd_basename(const std::string& path) { @@ -624,12 +660,23 @@ std::string get_image_params(SDParams params, int64_t seed) { } parameter_string += "Steps: " + std::to_string(params.sample_steps) + ", "; parameter_string += "CFG scale: " + std::to_string(params.cfg_scale) + ", "; + if (params.slg_scale != 0 && params.skip_layers.size() != 0) { + parameter_string += "SLG scale: " + std::to_string(params.cfg_scale) + ", "; + parameter_string += "Skip layers: ["; + for (const auto& layer : params.skip_layers) { + parameter_string += std::to_string(layer) + ", "; + } + parameter_string += "], "; + parameter_string += "Skip layer start: " + std::to_string(params.skip_layer_start) + ", "; + parameter_string += "Skip layer end: " + std::to_string(params.skip_layer_end) + ", "; + } parameter_string += "Guidance: " + std::to_string(params.guidance) + ", "; + parameter_string += "Eta: " + std::to_string(params.eta) + ", "; parameter_string += "Seed: " + std::to_string(seed) + ", "; parameter_string += "Size: " + std::to_string(params.width) + "x" + std::to_string(params.height) + ", "; parameter_string += "Model: " + sd_basename(params.model_path) + ", "; - parameter_string += "RNG: " + std::string(rng_type_to_str[params.rng_type]) + ", "; - parameter_string += "Sampler: " + std::string(sample_method_str[params.sample_method]); + parameter_string += "RNG: " + std::string(sd_rng_type_name(params.rng_type)) + ", "; + parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.sample_method)); if (params.schedule == KARRAS) { parameter_string += " karras"; } @@ -686,6 +733,18 @@ int main(int argc, const char* argv[]) { parse_args(argc, argv, params); + sd_guidance_params_t guidance_params = {params.cfg_scale, + params.img_cfg_scale, + params.min_cfg, + params.guidance, + { + params.skip_layers.data(), + params.skip_layers.size(), + params.skip_layer_start, + params.skip_layer_end, + params.slg_scale, + }}; + sd_set_log_callback(sd_log_cb, (void*)¶ms); if (params.verbose) { @@ -694,7 +753,7 @@ int main(int argc, const char* argv[]) { } if (params.mode == CONVERT) { - bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype); + bool success = convert(params.model_path.c_str(), params.vae_path.c_str(), params.output_path.c_str(), params.wtype, params.tensor_type_rules.c_str()); if (!success) { fprintf(stderr, "convert '%s'/'%s' to '%s' failed\n", @@ -711,7 +770,7 @@ int main(int argc, const char* argv[]) { } } - if (params.mode == IMG2VID) { + if (params.mode == VID_GEN) { fprintf(stderr, "SVD support is broken, do not use it!!!\n"); return 1; } @@ -719,7 +778,10 @@ int main(int argc, const char* argv[]) { bool vae_decode_only = true; uint8_t* input_image_buffer = NULL; uint8_t* control_image_buffer = NULL; - if (params.mode == IMG2IMG || params.mode == IMG2VID) { + uint8_t* mask_image_buffer = NULL; + std::vector ref_images; + + if (params.input_path.size() > 0) { vae_decode_only = false; int c = 0; @@ -769,37 +831,81 @@ int main(int argc, const char* argv[]) { free(input_image_buffer); input_image_buffer = resized_image_buffer; } + } else if (params.ref_image_paths.size() > 0) { + vae_decode_only = false; + for (auto& path : params.ref_image_paths) { + int c = 0; + int width = 0; + int height = 0; + uint8_t* image_buffer = stbi_load(path.c_str(), &width, &height, &c, 3); + if (image_buffer == NULL) { + fprintf(stderr, "load image from '%s' failed\n", path.c_str()); + return 1; + } + if (c < 3) { + fprintf(stderr, "the number of channels for the input image must be >= 3, but got %d channels\n", c); + free(image_buffer); + return 1; + } + if (width <= 0) { + fprintf(stderr, "error: the width of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + if (height <= 0) { + fprintf(stderr, "error: the height of image must be greater than 0\n"); + free(image_buffer); + return 1; + } + ref_images.push_back({(uint32_t)width, + (uint32_t)height, + 3, + image_buffer}); + } } - sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), - params.clip_l_path.c_str(), - params.clip_g_path.c_str(), - params.t5xxl_path.c_str(), - params.diffusion_model_path.c_str(), - params.vae_path.c_str(), - params.taesd_path.c_str(), - params.controlnet_path.c_str(), - params.lora_model_dir.c_str(), - params.embeddings_path.c_str(), - params.stacked_id_embeddings_path.c_str(), - vae_decode_only, - params.vae_tiling, - true, - params.n_threads, - params.wtype, - params.rng_type, - params.schedule, - params.clip_on_cpu, - params.control_net_cpu, - params.vae_on_cpu); + sd_ctx_params_t sd_ctx_params = { + params.model_path.c_str(), + params.clip_l_path.c_str(), + params.clip_g_path.c_str(), + params.t5xxl_path.c_str(), + params.diffusion_model_path.c_str(), + params.vae_path.c_str(), + params.taesd_path.c_str(), + params.control_net_path.c_str(), + params.lora_model_dir.c_str(), + params.embedding_dir.c_str(), + params.stacked_id_embed_dir.c_str(), + vae_decode_only, + params.vae_tiling, + true, + params.n_threads, + params.wtype, + params.rng_type, + params.schedule, + params.clip_on_cpu, + params.control_net_cpu, + params.vae_on_cpu, + params.diffusion_flash_attn, + params.chroma_use_dit_mask, + params.chroma_use_t5_mask, + params.chroma_t5_mask_pad, + }; + + sd_ctx_t* sd_ctx = new_sd_ctx(&sd_ctx_params); if (sd_ctx == NULL) { printf("new_sd_ctx_t failed\n"); return 1; } + sd_image_t input_image = {(uint32_t)params.width, + (uint32_t)params.height, + 3, + input_image_buffer}; + sd_image_t* control_image = NULL; - if (params.controlnet_path.size() > 0 && params.control_image_path.size() > 0) { + if (params.control_net_path.size() > 0 && params.control_image_path.size() > 0) { int c = 0; control_image_buffer = stbi_load(params.control_image_path.c_str(), ¶ms.width, ¶ms.height, &c, 3); if (control_image_buffer == NULL) { @@ -822,88 +928,65 @@ int main(int argc, const char* argv[]) { } } - sd_image_t* results; - if (params.mode == TXT2IMG) { - results = txt2img(sd_ctx, - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - params.cfg_scale, - params.guidance, - params.width, - params.height, - params.sample_method, - params.sample_steps, - params.seed, - params.batch_count, - control_image, - params.control_strength, - params.style_ratio, - params.normalize_input, - params.input_id_images_path.c_str()); + std::vector default_mask_image_vec(params.width * params.height, 255); + if (params.mask_path != "") { + int c = 0; + mask_image_buffer = stbi_load(params.mask_path.c_str(), ¶ms.width, ¶ms.height, &c, 1); } else { - sd_image_t input_image = {(uint32_t)params.width, - (uint32_t)params.height, - 3, - input_image_buffer}; - - if (params.mode == IMG2VID) { - results = img2vid(sd_ctx, - input_image, - params.width, - params.height, - params.video_frames, - params.motion_bucket_id, - params.fps, - params.augmentation_level, - params.min_cfg, - params.cfg_scale, - params.sample_method, - params.sample_steps, - params.strength, - params.seed); - if (results == NULL) { - printf("generate failed\n"); - free_sd_ctx(sd_ctx); - return 1; - } - size_t last = params.output_path.find_last_of("."); - std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path; - for (int i = 0; i < params.video_frames; i++) { - if (results[i].data == NULL) { - continue; - } - std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png"; - stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, - results[i].data, 0, get_image_params(params, params.seed + i).c_str()); - printf("save result image to '%s'\n", final_image_path.c_str()); - free(results[i].data); - results[i].data = NULL; - } - free(results); - free_sd_ctx(sd_ctx); - return 0; - } else { - results = img2img(sd_ctx, - input_image, - params.prompt.c_str(), - params.negative_prompt.c_str(), - params.clip_skip, - params.cfg_scale, - params.guidance, - params.width, - params.height, - params.sample_method, - params.sample_steps, - params.strength, - params.seed, - params.batch_count, - control_image, - params.control_strength, - params.style_ratio, - params.normalize_input, - params.input_id_images_path.c_str()); - } + mask_image_buffer = default_mask_image_vec.data(); + } + sd_image_t mask_image = {(uint32_t)params.width, + (uint32_t)params.height, + 1, + mask_image_buffer}; + + sd_image_t* results; + int expected_num_results = 1; + if (params.mode == IMG_GEN) { + sd_img_gen_params_t img_gen_params = { + params.prompt.c_str(), + params.negative_prompt.c_str(), + params.clip_skip, + guidance_params, + input_image, + ref_images.data(), + (int)ref_images.size(), + mask_image, + params.width, + params.height, + params.sample_method, + params.sample_steps, + params.eta, + params.strength, + params.seed, + params.batch_count, + control_image, + params.control_strength, + params.style_ratio, + params.normalize_input, + params.input_id_images_path.c_str(), + }; + + results = generate_image(sd_ctx, &img_gen_params); + expected_num_results = params.batch_count; + } else if (params.mode == VID_GEN) { + sd_vid_gen_params_t vid_gen_params = { + input_image, + params.width, + params.height, + guidance_params, + params.sample_method, + params.sample_steps, + params.strength, + params.seed, + params.video_frames, + params.motion_bucket_id, + params.fps, + params.augmentation_level, + }; + + results = generate_video(sd_ctx, &vid_gen_params); + expected_num_results = params.video_frames; } if (results == NULL) { @@ -915,8 +998,7 @@ int main(int argc, const char* argv[]) { int upscale_factor = 4; // unused for RealESRGAN_x4plus_anime_6B.pth if (params.esrgan_path.size() > 0 && params.upscale_repeats > 0) { upscaler_ctx_t* upscaler_ctx = new_upscaler_ctx(params.esrgan_path.c_str(), - params.n_threads, - params.wtype); + params.n_threads); if (upscaler_ctx == NULL) { printf("new_upscaler_ctx failed\n"); @@ -940,16 +1022,41 @@ int main(int argc, const char* argv[]) { } } - size_t last = params.output_path.find_last_of("."); - std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path; - for (int i = 0; i < params.batch_count; i++) { + std::string dummy_name, ext, lc_ext; + bool is_jpg; + size_t last = params.output_path.find_last_of("."); + size_t last_path = std::min(params.output_path.find_last_of("/"), + params.output_path.find_last_of("\\")); + if (last != std::string::npos // filename has extension + && (last_path == std::string::npos || last > last_path)) { + dummy_name = params.output_path.substr(0, last); + ext = lc_ext = params.output_path.substr(last); + std::transform(ext.begin(), ext.end(), lc_ext.begin(), ::tolower); + is_jpg = lc_ext == ".jpg" || lc_ext == ".jpeg" || lc_ext == ".jpe"; + } else { + dummy_name = params.output_path; + ext = lc_ext = ""; + is_jpg = false; + } + // appending ".png" to absent or unknown extension + if (!is_jpg && lc_ext != ".png") { + dummy_name += ext; + ext = ".png"; + } + for (int i = 0; i < expected_num_results; i++) { if (results[i].data == NULL) { continue; } - std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ".png" : dummy_name + ".png"; - stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, - results[i].data, 0, get_image_params(params, params.seed + i).c_str()); - printf("save result image to '%s'\n", final_image_path.c_str()); + std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1) + ext : dummy_name + ext; + if (is_jpg) { + stbi_write_jpg(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 90, get_image_params(params, params.seed + i).c_str()); + printf("save result JPEG image to '%s'\n", final_image_path.c_str()); + } else { + stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 0, get_image_params(params, params.seed + i).c_str()); + printf("save result PNG image to '%s'\n", final_image_path.c_str()); + } free(results[i].data); results[i].data = NULL; } diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt new file mode 100644 index 000000000..8cbd979d5 --- /dev/null +++ b/examples/server/CMakeLists.txt @@ -0,0 +1,6 @@ +set(TARGET sd-server) + +add_executable(${TARGET} main.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_link_libraries(${TARGET} PRIVATE stable-diffusion ${CMAKE_THREAD_LIBS_INIT}) +target_compile_features(${TARGET} PUBLIC cxx_std_17) \ No newline at end of file diff --git a/examples/server/b64.cpp b/examples/server/b64.cpp new file mode 100644 index 000000000..e0aacf045 --- /dev/null +++ b/examples/server/b64.cpp @@ -0,0 +1,42 @@ + +//FROM +//https://stackoverflow.com/a/34571089/5155484 + +static const std::string b = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";//= +static std::string base64_encode(const std::string &in) { + std::string out; + + int val=0, valb=-6; + for (uint8_t c : in) { + val = (val<<8) + c; + valb += 8; + while (valb>=0) { + out.push_back(b[(val>>valb)&0x3F]); + valb-=6; + } + } + if (valb>-6) out.push_back(b[((val<<8)>>(valb+8))&0x3F]); + while (out.size()%4) out.push_back('='); + return out; +} + + +static std::string base64_decode(const std::string &in) { + + std::string out; + + std::vector T(256,-1); + for (int i=0; i<64; i++) T[b[i]] = i; + + int val=0, valb=-8; + for (uint8_t c : in) { + if (T[c] == -1) break; + val = (val<<6) + T[c]; + valb += 6; + if (valb>=0) { + out.push_back(char((val>>valb)&0xFF)); + valb-=8; + } + } + return out; +} \ No newline at end of file diff --git a/examples/server/frontend.cpp b/examples/server/frontend.cpp new file mode 100644 index 000000000..b35279f52 --- /dev/null +++ b/examples/server/frontend.cpp @@ -0,0 +1,861 @@ +const std::string html_content = R"xxx( + + + + + + + + SDCPP Server + + +)xxx" +R"xxx( + + +

+

SDCPP Server

+

Model:

+
+
+
+
+
+ + +
+
+ + +
+ +
+
+

Settings

+
+
+ + +
+ +
+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+ +
+

Note: Changing these parameters may cause a longer wait time due to the models + reloading. Please use these parameters carefully.

+
+ + +
+
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+

Load new model

+
+ + +
+
+ + +
+
+
+ + +
+
+ + +
+
+ + +
+
+
+
+ + +
+
+ + +
+
+
+ + +
+
+
+
+
+
+ +
+ +
+ 0% + 0% + 0% +
+
+
+
+

Current task status: -- | Queue: 0

+
+ )xxx" + R"xxx( + + + + +)xxx"; \ No newline at end of file diff --git a/examples/server/main.cpp b/examples/server/main.cpp new file mode 100644 index 000000000..5d635f8ee --- /dev/null +++ b/examples/server/main.cpp @@ -0,0 +1,1666 @@ +#include +#include +#include +#include +#include +#include +#include +#include + +// #include "preprocessing.hpp" +#include "flux.hpp" +#include "stable-diffusion.h" + +#define STB_IMAGE_IMPLEMENTATION +#define STB_IMAGE_STATIC +#include "stb_image.h" + +#define STB_IMAGE_WRITE_IMPLEMENTATION +#define STB_IMAGE_WRITE_STATIC +#include "stb_image_write.h" + +#define STB_IMAGE_RESIZE_IMPLEMENTATION +#define STB_IMAGE_RESIZE_STATIC +#include "stb_image_resize.h" + +#include "b64.cpp" +#include "httplib.h" +#include "json.hpp" + +#include +#include +#include +#include +#include +#include + +#include "frontend.cpp" + +struct SDCtxParams { + std::string model_path; + std::string clip_l_path; + std::string clip_g_path; + std::string t5xxl_path; + std::string diffusion_model_path; + std::string vae_path; + std::string taesd_path; + + std::string control_net_path; + std::string lora_model_dir; + std::string embeddings_path; + std::string stacked_id_embeddings_path; + + bool vae_decode_only = false; + bool vae_tiling = false; + + int n_threads = -1; + sd_type_t wtype = SD_TYPE_COUNT; + + rng_type_t rng_type = CUDA_RNG; + schedule_t schedule = DEFAULT; + + bool keep_control_net_on_cpu = false; + bool keep_clip_on_cpu = false; + bool keep_vae_on_cpu = false; + + bool diffusion_flash_attn = false; +}; + +struct SDRequestParams { + // TODO set to true if esrgan_path is specified in args + bool upscale = false; + + std::string prompt; + std::string negative_prompt; + + float min_cfg = 1.0f; + float cfg_scale = 7.0f; + float guidance = 3.5f; + float style_ratio = 20.f; + int clip_skip = -1; // <= 0 represents unspecified + int width = 512; + int height = 512; + int batch_count = 1; + + sample_method_t sample_method = EULER_A; + int sample_steps = 20; + float strength = 0.75f; + float control_strength = 0.9f; + int64_t seed = 42; + + std::vector skip_layers = {7, 8, 9}; + float slg_scale = 0.; + float skip_layer_start = 0.01; + float skip_layer_end = 0.2; + bool normalize_input = false; + + // float apg_eta = 1.0f; + // float apg_momentum = 0.0f; + // float apg_norm_threshold = 0.0f; + // float apg_norm_smoothing = 0.0f; + + // sd_preview_t preview_method = SD_PREVIEW_NONE; + // int preview_interval = 1; +}; + +struct SDParams { + SDCtxParams ctxParams; + SDRequestParams lastRequest; + + std::string esrgan_path; + + std::string output_path = "./server/output.png"; + std::string input_path = "./server/input.png"; + std::string control_image_path = "./server/control.png"; + + // std::string preview_path = "./server/preview.png"; + + std::string models_dir; + std::string diffusion_models_dir; + std::string clip_dir; + std::string vae_dir; + std::string tae_dir; + + std::vector models_files; + std::vector diffusion_models_files; + std::vector clip_files; + std::vector vae_files; + std::vector tae_files; + + // external dir + std::string input_id_images_path; + + // Don't use TAE decoding by default + // bool taesd_preview = true; + + bool verbose = false; + + bool color = false; + + // server things + int port = 8080; + std::string host = "127.0.0.1"; +}; + +void print_params(SDParams params) { + printf("Starting Options: \n"); + printf(" n_threads: %d\n", params.ctxParams.n_threads); + printf(" mode: server\n"); + printf(" model_path: %s\n", params.ctxParams.model_path.c_str()); + printf(" wtype: %s\n", params.ctxParams.wtype < SD_TYPE_COUNT ? sd_type_name(params.ctxParams.wtype) : "unspecified"); + printf(" clip_l_path: %s\n", params.ctxParams.clip_l_path.c_str()); + printf(" clip_g_path: %s\n", params.ctxParams.clip_g_path.c_str()); + printf(" t5xxl_path: %s\n", params.ctxParams.t5xxl_path.c_str()); + printf(" diffusion_model_path: %s\n", params.ctxParams.diffusion_model_path.c_str()); + printf(" vae_path: %s\n", params.ctxParams.vae_path.c_str()); + printf(" taesd_path: %s\n", params.ctxParams.taesd_path.c_str()); + printf(" control_net_path: %s\n", params.ctxParams.control_net_path.c_str()); + printf(" embeddings_path: %s\n", params.ctxParams.embeddings_path.c_str()); + printf(" stacked_id_embeddings_path: %s\n", params.ctxParams.stacked_id_embeddings_path.c_str()); + printf(" input_id_images_path: %s\n", params.input_id_images_path.c_str()); + printf(" style ratio: %.2f\n", params.lastRequest.style_ratio); + printf(" normalize input image : %s\n", params.lastRequest.normalize_input ? "true" : "false"); + printf(" output_path: %s\n", params.output_path.c_str()); + printf(" init_img: %s\n", params.input_path.c_str()); + printf(" control_image: %s\n", params.control_image_path.c_str()); + printf(" clip on cpu: %s\n", params.ctxParams.keep_clip_on_cpu ? "true" : "false"); + printf(" control_net cpu: %s\n", params.ctxParams.keep_control_net_on_cpu ? "true" : "false"); + printf(" vae decoder on cpu:%s\n", params.ctxParams.keep_vae_on_cpu ? "true" : "false"); + printf(" diffusion flash attention:%s\n", params.ctxParams.diffusion_flash_attn ? "true" : "false"); + printf(" strength(control): %.2f\n", params.lastRequest.control_strength); + printf(" prompt: %s\n", params.lastRequest.prompt.c_str()); + printf(" negative_prompt: %s\n", params.lastRequest.negative_prompt.c_str()); + printf(" min_cfg: %.2f\n", params.lastRequest.min_cfg); + printf(" cfg_scale: %.2f\n", params.lastRequest.cfg_scale); + printf(" slg_scale: %.2f\n", params.lastRequest.slg_scale); + printf(" guidance: %.2f\n", params.lastRequest.guidance); + printf(" clip_skip: %d\n", params.lastRequest.clip_skip); + printf(" width: %d\n", params.lastRequest.width); + printf(" height: %d\n", params.lastRequest.height); + printf(" sample_method: %s\n", sd_sample_method_name(params.lastRequest.sample_method)); + printf(" schedule: %s\n", sd_schedule_name(params.ctxParams.schedule)); + printf(" sample_steps: %d\n", params.lastRequest.sample_steps); + printf(" strength(img2img): %.2f\n", params.lastRequest.strength); + printf(" rng: %s\n", sd_rng_type_name(params.ctxParams.rng_type)); + printf(" seed: %ld\n", params.lastRequest.seed); + printf(" batch_count: %d\n", params.lastRequest.batch_count); + printf(" vae_tiling: %s\n", params.ctxParams.vae_tiling ? "true" : "false"); +} + +void print_usage(int argc, const char* argv[]) { + printf("usage: %s [arguments]\n", argv[0]); + printf("\n"); + printf("arguments:\n"); + printf(" -h, --help show this help message and exit\n"); + printf(" -t, --threads N number of threads to use during computation (default: -1)\n"); + printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); + printf(" -m, --model [MODEL] path to full model\n"); + printf(" --diffusion-model path to the standalone diffusion model\n"); + printf(" --clip_l path to the clip-l text encoder\n"); + printf(" --clip_g path to the clip-g text encoder\n"); + printf(" --t5xxl path to the the t5xxl text encoder\n"); + printf(" --vae [VAE] path to vae\n"); + printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); + printf(" --control-net [CONTROL_PATH] path to control net model\n"); + printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n"); + printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n"); + printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n"); + printf(" --normalize-input normalize PHOTOMAKER input id images\n"); + // printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); + // printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); + printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n"); + printf(" If not specified, the default is the type of the weight file\n"); + printf(" --lora-model-dir [DIR] lora model directory\n"); + printf(" --control-image [IMAGE] path to image condition, control net\n"); + printf(" -o, --output OUTPUT path to write result image to (default: ./output.png)\n"); + printf(" -p, --prompt [PROMPT] the prompt to render\n"); + printf(" -n, --negative-prompt PROMPT the negative prompt (default: \"\")\n"); + printf(" --cfg-scale SCALE unconditional guidance scale: (default: 7.0)\n"); + printf(" --slg-scale SCALE skip layer guidance (SLG) scale, only for DiT models: (default: 0)\n"); + printf(" 0 means disabled, a value of 2.5 is nice for sd3.5 medium\n"); + printf(" --skip_layers LAYERS Layers to skip for SLG steps: (default: [7,8,9])\n"); + printf(" --skip_layer_start START SLG enabling point: (default: 0.01)\n"); + printf(" --skip_layer_end END SLG disabling point: (default: 0.2)\n"); + printf(" SLG will be enabled at step int([STEPS]*[START]) and disabled at int([STEPS]*[END])\n"); + printf(" --strength STRENGTH strength for noising/unnoising (default: 0.75)\n"); + printf(" --style-ratio STYLE-RATIO strength for keeping input identity (default: 20%%)\n"); + printf(" --control-strength STRENGTH strength to apply Control Net (default: 0.9)\n"); + printf(" 1.0 corresponds to full destruction of information in init image\n"); + printf(" -H, --height H image height, in pixel space (default: 512)\n"); + printf(" -W, --width W image width, in pixel space (default: 512)\n"); + printf(" --sampling-method {euler, euler_a, heun, dpm2, dpm++2s_a, dpm++2m, dpm++2mv2, ipndm, ipndm_v, lcm}\n"); + printf(" sampling method (default: \"euler_a\")\n"); + printf(" --steps STEPS number of sample steps (default: 20)\n"); + printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); + printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); + printf(" -b, --batch-count COUNT number of images to generate\n"); + printf(" --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete)\n"); + printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); + printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); + printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); + printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); + printf(" --clip-on-cpu keep clip in cpu (for low vram)\n"); + printf(" --diffusion-fa use flash attention in the diffusion model (for low vram)\n"); + printf(" Might lower quality, since it implies converting k and v to f16.\n"); + printf(" This might crash if it is not supported by the backend.\n"); + printf(" --control-net-cpu keep control_net in cpu (for low vram)\n"); + printf(" --canny apply canny preprocessor (edge detection)\n"); + printf(" --color Colors the logging tags according to level\n"); + printf(" -v, --verbose print extra info\n"); + printf(" --port port used for server (default: 8080)\n"); + printf(" --host IP address used for server. Use 0.0.0.0 to expose server to LAN (default: localhost)\n"); +} + +void parse_args(int argc, const char** argv, SDParams& params) { + bool invalid_arg = false; + std::string arg; + for (int i = 1; i < argc; i++) { + arg = argv[i]; + + if (arg == "-t" || arg == "--threads") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.n_threads = std::stoi(argv[i]); + } else if (arg == "-m" || arg == "--model") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.model_path = argv[i]; + } else if (arg == "--clip_l") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.clip_l_path = argv[i]; + } else if (arg == "--clip_g") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.clip_g_path = argv[i]; + } else if (arg == "--t5xxl") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.t5xxl_path = argv[i]; + } else if (arg == "--diffusion-model") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.diffusion_model_path = argv[i]; + } else if (arg == "--vae") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.vae_path = argv[i]; + } else if (arg == "--taesd") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.taesd_path = argv[i]; + } else if (arg == "--control-net") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.control_net_path = argv[i]; + } else if (arg == "--upscale-model") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.esrgan_path = argv[i]; + } else if (arg == "--embd-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.embeddings_path = argv[i]; + } else if (arg == "--stacked-id-embd-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.stacked_id_embeddings_path = argv[i]; + } else if (arg == "--input-id-images-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.input_id_images_path = argv[i]; + } else if (arg == "--type") { + if (++i >= argc) { + invalid_arg = true; + break; + } + std::string type = argv[i]; + bool found = false; + std::string valid_types = ""; + for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + std::string name(trait->type_name); + if (name == "f32" || trait->to_float && trait->type_size) { + if (i) + valid_types += ", "; + valid_types += name; + if (type == name) { + if (ggml_quantize_requires_imatrix((ggml_type)i)) { + printf("\033[35;1m[WARNING]\033[0m: type %s requires imatrix to work properly. A dummy imatrix will be used, expect poor quality.\n", trait->type_name); + } + params.ctxParams.wtype = (enum sd_type_t)i; + found = true; + break; + } + } + } + if (!found) { + fprintf(stderr, "error: invalid weight format %s, must be one of [%s]\n", + type.c_str(), + valid_types.c_str()); + exit(1); + } + } else if (arg == "--lora-model-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.ctxParams.lora_model_dir = argv[i]; + } else if (arg == "-i" || arg == "--init-img") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.input_path = argv[i]; + } else if (arg == "--control-image") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.control_image_path = argv[i]; + } else if (arg == "-o" || arg == "--output") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.output_path = argv[i]; + } else if (arg == "-p" || arg == "--prompt") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.prompt = argv[i]; + } else if (arg == "-n" || arg == "--negative-prompt") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.negative_prompt = argv[i]; + } else if (arg == "--cfg-scale") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.cfg_scale = std::stof(argv[i]); + } else if (arg == "--guidance") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.guidance = std::stof(argv[i]); + } else if (arg == "--strength") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.strength = std::stof(argv[i]); + } else if (arg == "--style-ratio") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.style_ratio = std::stof(argv[i]); + } else if (arg == "--control-strength") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.control_strength = std::stof(argv[i]); + } else if (arg == "-H" || arg == "--height") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.height = std::stoi(argv[i]); + } else if (arg == "-W" || arg == "--width") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.width = std::stoi(argv[i]); + } else if (arg == "--steps") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.sample_steps = std::stoi(argv[i]); + } else if (arg == "--clip-skip") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.clip_skip = std::stoi(argv[i]); + } else if (arg == "--vae-tiling") { + params.ctxParams.vae_tiling = true; + } else if (arg == "--control-net-cpu") { + params.ctxParams.keep_control_net_on_cpu = true; + } else if (arg == "--normalize-input") { + params.lastRequest.normalize_input = true; + } else if (arg == "--clip-on-cpu") { + params.ctxParams.keep_clip_on_cpu = true; // will slow down get_learned_condiotion but necessary for low MEM GPUs + } else if (arg == "--vae-on-cpu") { + params.ctxParams.keep_vae_on_cpu = true; // will slow down latent decoding but necessary for low MEM GPUs + } else if (arg == "--diffusion-fa") { + params.ctxParams.diffusion_flash_attn = true; // can reduce MEM significantly + } else if (arg == "-b" || arg == "--batch-count") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.batch_count = std::stoi(argv[i]); + } else if (arg == "--rng") { + if (++i >= argc) { + invalid_arg = true; + break; + } + std::string rng_type_str = argv[i]; + if (rng_type_str == "std_default") { + params.ctxParams.rng_type = STD_DEFAULT_RNG; + } else if (rng_type_str == "cuda") { + params.ctxParams.rng_type = CUDA_RNG; + } else { + invalid_arg = true; + break; + } + } else if (arg == "--schedule") { + if (++i >= argc) { + invalid_arg = true; + break; + } + const char* schedule_selected = argv[i]; + schedule_t schedule_found = str_to_schedule(schedule_selected); + if (schedule_found == SCHEDULE_COUNT) { + invalid_arg = true; + break; + } + params.ctxParams.schedule = schedule_found; + } else if (arg == "-s" || arg == "--seed") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.seed = std::stoll(argv[i]); + } else if (arg == "--sampling-method") { + if (++i >= argc) { + invalid_arg = true; + break; + } + const char* sample_method_selected = argv[i]; + int sample_method_found = str_to_sample_method(sample_method_selected); + if (sample_method_found == SAMPLE_METHOD_COUNT) { + invalid_arg = true; + break; + } + params.lastRequest.sample_method = (sample_method_t)sample_method_found; + } else if (arg == "-h" || arg == "--help") { + print_usage(argc, argv); + exit(0); + } else if (arg == "-v" || arg == "--verbose") { + params.verbose = true; + } else if (arg == "--color") { + params.color = true; + } else if (arg == "--slg-scale") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.slg_scale = std::stof(argv[i]); + } else if (arg == "--skip-layers") { + if (++i >= argc) { + invalid_arg = true; + break; + } + if (argv[i][0] != '[') { + invalid_arg = true; + break; + } + std::string layers_str = argv[i]; + while (layers_str.back() != ']') { + if (++i >= argc) { + invalid_arg = true; + break; + } + layers_str += " " + std::string(argv[i]); + } + layers_str = layers_str.substr(1, layers_str.size() - 2); + + std::regex regex("[, ]+"); + std::sregex_token_iterator iter(layers_str.begin(), layers_str.end(), regex, -1); + std::sregex_token_iterator end; + std::vector tokens(iter, end); + std::vector layers; + for (const auto& token : tokens) { + try { + layers.push_back(std::stoi(token)); + } catch (const std::invalid_argument& e) { + invalid_arg = true; + break; + } + } + params.lastRequest.skip_layers = layers; + + if (invalid_arg) { + break; + } + } else if (arg == "--skip-layer-start") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.skip_layer_start = std::stof(argv[i]); + } else if (arg == "--skip-layer-end") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.lastRequest.skip_layer_end = std::stof(argv[i]); + } else if (arg == "--port") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.port = std::stoi(argv[i]); + } else if (arg == "--host") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.host = argv[i]; + } else if (arg == "--models-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.models_dir = argv[i]; + } else if (arg == "--diffusion-models-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.diffusion_models_dir = argv[i]; + } else if (arg == "--encoders-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_dir = argv[i]; + } else if (arg == "--vae-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.vae_dir = argv[i]; + } else if (arg == "--tae-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.tae_dir = argv[i]; + } else { + fprintf(stderr, "error: unknown argument: %s\n", arg.c_str()); + print_usage(argc, argv); + exit(1); + } + } + if (invalid_arg) { + fprintf(stderr, "error: invalid parameter for argument: %s\n", arg.c_str()); + print_usage(argc, argv); + exit(1); + } + if (params.ctxParams.n_threads <= 0) { + params.ctxParams.n_threads = get_num_physical_cores(); + } +} + +static std::string sd_basename(const std::string& path) { + size_t pos = path.find_last_of('/'); + if (pos != std::string::npos) { + return path.substr(pos + 1); + } + pos = path.find_last_of('\\'); + if (pos != std::string::npos) { + return path.substr(pos + 1); + } + return path; +} + +std::string get_image_params(SDParams params, int64_t seed) { + std::string parameter_string = params.lastRequest.prompt + "\n"; + if (params.lastRequest.negative_prompt.size() != 0) { + parameter_string += "Negative prompt: " + params.lastRequest.negative_prompt + "\n"; + } + parameter_string += "Steps: " + std::to_string(params.lastRequest.sample_steps) + ", "; + parameter_string += "CFG scale: " + std::to_string(params.lastRequest.cfg_scale) + ", "; + parameter_string += "Guidance: " + std::to_string(params.lastRequest.guidance) + ", "; + parameter_string += "Seed: " + std::to_string(seed) + ", "; + parameter_string += "Size: " + std::to_string(params.lastRequest.width) + "x" + std::to_string(params.lastRequest.height) + ", "; + parameter_string += "Model: " + sd_basename(params.ctxParams.model_path) + ", "; + parameter_string += "RNG: " + std::string(sd_rng_type_name(params.ctxParams.rng_type)) + ", "; + parameter_string += "Sampler: " + std::string(sd_sample_method_name(params.lastRequest.sample_method)); + if (params.ctxParams.schedule == KARRAS) { + parameter_string += " karras"; + } + parameter_string += ", "; + parameter_string += "Version: stable-diffusion.cpp"; + return parameter_string; +} + +/* Enables Printing the log level tag in color using ANSI escape codes */ +void sd_log_cb(enum sd_log_level_t level, const char* log, void* data) { + SDParams* params = (SDParams*)data; + int tag_color; + const char* level_str; + FILE* out_stream = (level == SD_LOG_ERROR) ? stderr : stdout; + + if (!log || (!params->verbose && level <= SD_LOG_DEBUG)) { + return; + } + + switch (level) { + case SD_LOG_DEBUG: + tag_color = 37; + level_str = "DEBUG"; + break; + case SD_LOG_INFO: + tag_color = 34; + level_str = "INFO"; + break; + case SD_LOG_WARN: + tag_color = 35; + level_str = "WARN"; + break; + case SD_LOG_ERROR: + tag_color = 31; + level_str = "ERROR"; + break; + default: /* Potential future-proofing */ + tag_color = 33; + level_str = "?????"; + break; + } + + if (params->color == true) { + fprintf(out_stream, "\033[%d;1m[%-5s]\033[0m ", tag_color, level_str); + } else { + fprintf(out_stream, "[%-5s] ", level_str); + } + fputs(log, out_stream); + fflush(out_stream); +} + +void* server_log_params = NULL; + +// enable logging in the server +#define LOG_BUFFER_SIZE 1024 +void sd_log(enum sd_log_level_t level, const char* format, ...) { + va_list args; + va_start(args, format); + + char log[LOG_BUFFER_SIZE]; + vsnprintf(log, 1024, format, args); + strncat(log, "\n", LOG_BUFFER_SIZE - strlen(log)); + + sd_log_cb(level, log, server_log_params); + va_end(args); +} + +static void log_server_request(const httplib::Request& req, const httplib::Response& res) { + printf("request: %s %s (%s)\n", req.method.c_str(), req.path.c_str(), req.body.c_str()); +} + +bool parseJsonPrompt(std::string json_str, SDParams* params) { + bool updatectx = false; + using namespace nlohmann; + json payload = json::parse(json_str); + // if no exception, the request is a json object + // now we try to get the new param values from the payload object + // const char *prompt, const char *negative_prompt, int clip_skip, float cfg_scale, float guidance, int width, int height, sample_method_t sample_method, int sample_steps, int64_t seed, int batch_count, const sd_image_t *control_cond, float control_strength, float style_strength, bool normalize_input, const char *input_id_images_path + try { + std::string prompt = payload["prompt"]; + params->lastRequest.prompt = prompt; + } catch (...) { + } + try { + std::string negative_prompt = payload["negative_prompt"]; + params->lastRequest.negative_prompt = negative_prompt; + } catch (...) { + } + try { + int clip_skip = payload["clip_skip"]; + params->lastRequest.clip_skip = clip_skip; + } catch (...) { + } + try { + json guidance_params = payload["guidance_params"]; + try { + float cfg_scale = guidance_params["cfg_scale"]; + params->lastRequest.cfg_scale = cfg_scale; + } catch (...) { + } + try { + float guidance = guidance_params["guidance"]; + params->lastRequest.guidance = guidance; + } catch (...) { + } + try { + json slg = guidance_params["slg"]; + try { + params->lastRequest.skip_layers = slg["layers"].get>(); + } catch (...) { + } + try { + float slg_scale = slg["scale"]; + params->lastRequest.slg_scale = slg_scale; + } catch (...) { + } + try { + float skip_layer_start = slg["start"]; + params->lastRequest.skip_layer_start = skip_layer_start; + } catch (...) { + } + try { + float skip_layer_end = slg["end"]; + params->lastRequest.skip_layer_end = skip_layer_end; + } catch (...) { + } + } catch (...) { + } + // try { + // json apg = guidance_params["apg"]; + // try { + // float apg_eta = apg["eta"]; + // params->lastRequest.apg_eta = apg_eta; + // } catch (...) { + // } + // try { + // float apg_momentum = apg["momentum"]; + // params->lastRequest.apg_momentum = apg_momentum; + // } catch (...) { + // } + // try { + // float apg_norm_threshold = apg["norm_threshold"]; + // params->lastRequest.apg_norm_threshold = apg_norm_threshold; + // } catch (...) { + // } + // try { + // float apg_norm_smoothing = apg["norm_smoothing"]; + // params->lastRequest.apg_norm_smoothing = apg_norm_smoothing; + // } catch (...) { + // } + // } catch (...) { + // } + } catch (...) { + } + try { + int width = payload["width"]; + params->lastRequest.width = width; + } catch (...) { + } + try { + int height = payload["height"]; + params->lastRequest.height = height; + } catch (...) { + } + try { + std::string sample_method = payload["sample_method"]; + + sample_method_t sample_method_found = str_to_sample_method(sample_method.c_str()); + if (sample_method_found != SAMPLE_METHOD_COUNT) { + params->lastRequest.sample_method = sample_method_found; + } else { + sd_log(sd_log_level_t::SD_LOG_WARN, "Unknown sampling method: %s\n", sample_method.c_str()); + } + } catch (...) { + } + try { + int sample_steps = payload["sample_steps"]; + params->lastRequest.sample_steps = sample_steps; + } catch (...) { + } + try { + int64_t seed = payload["seed"]; + params->lastRequest.seed = seed; + } catch (...) { + } + try { + int batch_count = payload["batch_count"]; + params->lastRequest.batch_count = batch_count; + } catch (...) { + } + + try { + std::string control_cond = payload["control_cond"]; + + // TODO map to enum value + // LOG_WARN("control_cond is not supported yet\n"); + sd_log(sd_log_level_t::SD_LOG_WARN, "control_cond is not supported yet\n"); + } catch (...) { + } + try { + float control_strength = payload["control_strength"]; + // params->control_strength = control_strength; + // LOG_WARN("control_strength is not supported yet\n"); + sd_log(sd_log_level_t::SD_LOG_WARN, "control_strength is not supported yet\n", params); + } catch (...) { + } + try { + float style_strength = payload["style_strength"]; + // params->style_strength = style_strength; + // LOG_WARN("style_strength is not supported yet\n"); + sd_log(sd_log_level_t::SD_LOG_WARN, "style_strength is not supported yet\n", params); + } catch (...) { + } + try { + bool normalize_input = payload["normalize_input"]; + params->lastRequest.normalize_input = normalize_input; + } catch (...) { + } + try { + std::string input_id_images_path = payload["input_id_images_path"]; + // TODO replace with b64 image maybe? + params->input_id_images_path = input_id_images_path; + } catch (...) { + } + + try { + bool vae_cpu = payload["keep_vae_on_cpu"]; + if (params->ctxParams.keep_vae_on_cpu != vae_cpu) { + params->ctxParams.keep_vae_on_cpu = vae_cpu; + updatectx = true; + } + } catch (...) { + } + try { + bool clip_cpu = payload["keep_clip_on_cpu"]; + if (params->ctxParams.keep_clip_on_cpu != clip_cpu) { + params->ctxParams.keep_clip_on_cpu = clip_cpu; + updatectx = true; + } + } catch (...) { + } + try { + bool vae_tiling = payload["vae_tiling"]; + if (params->ctxParams.vae_tiling != vae_tiling) { + params->ctxParams.vae_tiling = vae_tiling; + updatectx = true; + } + } catch (...) { + } + const int MODEL_UNLOAD = -2; + const int MODEL_KEEP = -1; + try { + int model_index = payload["model"]; + if (model_index >= 0 && model_index < params->models_files.size()) { + std::string new_path = params->models_dir + params->models_files[model_index]; + if (params->ctxParams.model_path != new_path) { + params->ctxParams.model_path = new_path; + params->ctxParams.diffusion_model_path = ""; + updatectx = true; + } + } else { + if (model_index == MODEL_UNLOAD) { + if (params->ctxParams.model_path != "") { + updatectx = true; + } + params->ctxParams.model_path = ""; + } else if (model_index != MODEL_KEEP) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid model index: %d\n", model_index); + } + } + } catch (...) { + } + try { + int diffusion_model_index = payload["diffusion_model"]; + if (diffusion_model_index >= 0 && diffusion_model_index < params->diffusion_models_files.size()) { + std::string new_path = params->diffusion_models_dir + params->diffusion_models_files[diffusion_model_index]; + if (params->ctxParams.diffusion_model_path != new_path) { + params->ctxParams.diffusion_model_path = new_path; + params->ctxParams.model_path = ""; + updatectx = true; + } + } else if (diffusion_model_index == MODEL_UNLOAD) { + if (params->ctxParams.diffusion_model_path != "") { + updatectx = true; + } + params->ctxParams.diffusion_model_path = ""; + } else if (diffusion_model_index != MODEL_KEEP) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid diffusion model index: %d\n", diffusion_model_index); + } + } catch (...) { + } + try { + int clip_l_index = payload["clip_l"]; + if (clip_l_index >= 0 && clip_l_index < params->clip_files.size()) { + std::string new_path = params->clip_dir + params->clip_files[clip_l_index]; + if (params->ctxParams.clip_l_path != new_path) { + params->ctxParams.clip_l_path = new_path; + updatectx = true; + } + } else if (clip_l_index == MODEL_UNLOAD) { + if (params->ctxParams.clip_l_path != "") { + updatectx = true; + } + params->ctxParams.clip_l_path = ""; + } else if (clip_l_index != MODEL_KEEP) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid clip_l index: %d\n", clip_l_index); + } + } catch (...) { + } + try { + int clip_g_index = payload["clip_g"]; + if (clip_g_index >= 0 && clip_g_index < params->clip_files.size()) { + std::string new_path = params->clip_dir + params->clip_files[clip_g_index]; + if (params->ctxParams.clip_g_path != new_path) { + params->ctxParams.clip_g_path = new_path; + updatectx = true; + } + } else if (clip_g_index == MODEL_UNLOAD) { + if (params->ctxParams.clip_g_path != "") { + updatectx = true; + } + params->ctxParams.clip_g_path = ""; + } else if (clip_g_index != MODEL_KEEP) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid clip_g index: %d\n", clip_g_index); + } + } catch (...) { + } + try { + int t5xxl_index = payload["t5xxl"]; + if (t5xxl_index >= 0 && t5xxl_index < params->clip_files.size()) { + std::string new_path = params->clip_dir + params->clip_files[t5xxl_index]; + if (params->ctxParams.t5xxl_path != new_path) { + params->ctxParams.t5xxl_path = new_path; + updatectx = true; + } + } else if (t5xxl_index == MODEL_UNLOAD) { + if (params->ctxParams.t5xxl_path != "") { + updatectx = true; + } + params->ctxParams.t5xxl_path = ""; + } else if (t5xxl_index != MODEL_KEEP) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid t5xxl index: %d\n", t5xxl_index); + } + } catch (...) { + } + try { + int vae_index = payload["vae"]; + if (vae_index >= 0 && vae_index < params->vae_files.size()) { + std::string new_path = params->vae_dir + params->vae_files[vae_index]; + if (params->ctxParams.vae_path != new_path) { + params->ctxParams.vae_path = new_path; + updatectx = true; + } + } else if (vae_index == MODEL_UNLOAD) { + if (params->ctxParams.vae_path != "") { + updatectx = true; + } + params->ctxParams.vae_path = ""; + } else if (vae_index != MODEL_KEEP) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid vae index: %d\n", vae_index); + } + } catch (...) { + } + try { + int tae_index = payload["tae"]; + if (tae_index >= 0 && tae_index < params->tae_files.size()) { + std::string new_path = params->tae_dir + params->tae_files[tae_index]; + if (params->ctxParams.taesd_path != new_path) { + params->ctxParams.taesd_path = new_path; + updatectx = true; + } + } else if (tae_index == MODEL_UNLOAD) { + if (params->ctxParams.taesd_path != "") { + updatectx = true; + } + params->ctxParams.taesd_path = ""; + } else if (tae_index != MODEL_KEEP) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Invalid tae index: %d\n", tae_index); + } + } catch (...) { + } + + try { + std::string schedule = payload["schedule"]; + schedule_t schedule_found = str_to_schedule(schedule.c_str()); + if (schedule_found != SCHEDULE_COUNT) { + if (params->ctxParams.schedule != schedule_found) { + params->ctxParams.schedule = schedule_found; + updatectx = true; + } + } else { + sd_log(sd_log_level_t::SD_LOG_WARN, "Unknown schedule: %s\n", schedule.c_str()); + } + } catch (...) { + } + + // try { + // bool tae_decode = payload["tae_decode"]; + // if (params->taesd_preview == tae_decode) { + // params->taesd_preview = !tae_decode; + // updatectx = true; + // } + // } catch (...) { + // } + + // try { + // std::string preview = payload["preview_mode"]; + // int preview_found = -1; + // for (int m = 0; m < N_PREVIEWS; m++) { + // if (!strcmp(preview.c_str(), previews_str[m])) { + // preview_found = m; + // } + // } + // if (preview_found >= 0) { + // if (params->lastRequest.preview_method != (sd_preview_t)preview_found) { + // params->lastRequest.preview_method = (sd_preview_t)preview_found; + // } + // } else { + // sd_log(sd_log_level_t::SD_LOG_WARN, "Unknown preview: %s\n", preview.c_str()); + // } + // } catch (...) { + // } + // try { + // int interval = payload["preview_interval"]; + // params->lastRequest.preview_interval = interval; + // } catch (...) { + // } + try { + std::string type = payload["type"]; + if (type != "") { + for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + std::string name(trait->type_name); + if (name == "f32" || trait->to_float && trait->type_size) { + if (type == name) { + params->ctxParams.wtype = (enum sd_type_t)i; + updatectx = true; + break; + } + } + } + } + } catch (...) { + } + return updatectx; +} + +std::vector list_files(const std::string& dir_path) { + namespace fs = std::filesystem; + std::vector files; + if (dir_path != "") + for (const auto& entry : fs::recursive_directory_iterator(dir_path)) { + if (entry.is_regular_file()) { + auto relative_path = fs::relative(entry.path(), dir_path); + std::string path_str = relative_path.string(); + std::replace(path_str.begin(), path_str.end(), '\\', '/'); + files.push_back(path_str); + } + } + return files; +} + +//--------------------------------------// +// Thread-safe queue +std::queue> task_queue; +std::mutex queue_mutex; +std::condition_variable queue_cond; +bool stop_worker = false; +std::atomic is_busy(false); +std::string running_task_id(""); + +std::unordered_map task_results; +std::mutex results_mutex; + +void worker_thread() { + while (!stop_worker) { + std::unique_lock lock(queue_mutex); + queue_cond.wait(lock, [] { return !task_queue.empty() || stop_worker; }); + + if (!task_queue.empty()) { + is_busy = true; + auto task = task_queue.front(); + task_queue.pop(); + lock.unlock(); + task(); + is_busy = false; + running_task_id = ""; + } + } +} + +void add_task(std::string task_id, std::function task) { + std::lock_guard lock(queue_mutex); + task_queue.push([task_id, task]() { + task(); + }); + queue_cond.notify_one(); +} + +void update_progress_cb(int step, int steps, float time, void* _data) { + using json = nlohmann::json; + if (running_task_id != "") { + std::lock_guard results_lock(results_mutex); + json running_task_json = task_results[running_task_id]; + if (running_task_json["status"] == "Working" && running_task_json["step"] == running_task_json["steps"]) { + running_task_json["status"] = "Decoding"; + } + running_task_json["step"] = step; + running_task_json["steps"] = steps; + task_results[running_task_id] = running_task_json; + } +} + +bool is_model_file(const std::string& path) { + size_t name_start = path.find_last_of("/\\"); + if (name_start == std::string::npos) { + name_start = 0; + } + size_t extension_start = path.substr(name_start).find_last_of("."); + if (extension_start == std::string::npos) { + return false; // No extension + } + std::string file_extension = path.substr(name_start + extension_start + 1); + return (file_extension == "gguf" || file_extension == "safetensors" || file_extension == "sft" || file_extension == "ckpt"); +} + +void start_server(SDParams params) { + // preview_path = params.preview_path.c_str(); + sd_set_log_callback(sd_log_cb, (void*)¶ms); + sd_set_progress_callback(update_progress_cb, NULL); + + params.models_files = list_files(params.models_dir); + params.diffusion_models_files = list_files(params.diffusion_models_dir); + params.clip_files = list_files(params.clip_dir); + params.vae_files = list_files(params.vae_dir); + params.tae_files = list_files(params.tae_dir); + std::vector lora_files = list_files(params.ctxParams.lora_model_dir); + + + server_log_params = (void*)¶ms; + + if (params.verbose) { + print_params(params); + printf("%s", sd_get_system_info()); + } + + sd_ctx_t* sd_ctx = NULL; + + int n_prompts = 0; + + std::unique_ptr svr; + svr.reset(new httplib::Server()); + svr->set_default_headers({{"Server", "sd.cpp"}}); + // CORS preflight + svr->Options(R"(.*)", [](const httplib::Request&, httplib::Response& res) { + // Access-Control-Allow-Origin is already set by middleware + res.set_header("Access-Control-Allow-Credentials", "true"); + res.set_header("Access-Control-Allow-Methods", "POST"); + res.set_header("Access-Control-Allow-Headers", "*"); + return res.set_content("", "text/html"); // blank response, no data + }); + if (params.verbose) { + svr->set_logger(log_server_request); + } + + svr->Post("/txt2img", [&sd_ctx, ¶ms, &n_prompts](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + std::string task_id = std::to_string(std::chrono::system_clock::now().time_since_epoch().count()); + + { + json pending_task_json = json::object(); + pending_task_json["status"] = "Pending"; + pending_task_json["data"] = json::array(); + pending_task_json["step"] = -1; + pending_task_json["steps"] = 0; + pending_task_json["eta"] = "?"; + + std::lock_guard results_lock(results_mutex); + task_results[task_id] = pending_task_json; + } + + auto task = [req, &sd_ctx, ¶ms, &n_prompts, task_id]() { + running_task_id = task_id; + // LOG_DEBUG("raw body is: %s\n", req.body.c_str()); + sd_log(sd_log_level_t::SD_LOG_DEBUG, "raw body is: %s\n", req.body.c_str()); + // parse req.body as json using jsoncpp + bool updateCTX = false; + try { + std::string json_str = req.body; + updateCTX = parseJsonPrompt(json_str, ¶ms); + } catch (json::parse_error& e) { + // assume the request is just a prompt + // LOG_WARN("Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what()); + sd_log(sd_log_level_t::SD_LOG_WARN, "Failed to parse json: %s\n Assuming it's just a prompt...\n", e.what()); + std::string prompt = req.body; + if (!prompt.empty()) { + params.lastRequest.prompt = prompt; + } else { + params.lastRequest.seed += 1; + } + } catch (...) { + // Handle any other type of exception + // LOG_ERROR("An unexpected error occurred\n"); + sd_log(sd_log_level_t::SD_LOG_ERROR, "An unexpected error occurred\n"); + } + // LOG_DEBUG("prompt is: %s\n", params.prompt.c_str()); + sd_log(sd_log_level_t::SD_LOG_INFO, "prompt is: %s\n", params.lastRequest.prompt.c_str()); + + if (updateCTX && sd_ctx != NULL) { + free_sd_ctx(sd_ctx); + sd_ctx = NULL; + } + + if (sd_ctx == NULL) { + printf("Loading sd_ctx\n"); + { + json task_json = json::object(); + task_json["status"] = "Loading"; + task_json["data"] = json::array(); + task_json["step"] = -1; + task_json["steps"] = 0; + task_json["eta"] = "?"; + + std::lock_guard results_lock(results_mutex); + task_results[task_id] = task_json; + } + sd_ctx_params_t sd_ctx_params = { + params.ctxParams.model_path.c_str(), + params.ctxParams.clip_l_path.c_str(), + params.ctxParams.clip_g_path.c_str(), + params.ctxParams.t5xxl_path.c_str(), + params.ctxParams.diffusion_model_path.c_str(), + params.ctxParams.vae_path.c_str(), + params.ctxParams.taesd_path.c_str(), + params.ctxParams.control_net_path.c_str(), + params.ctxParams.lora_model_dir.c_str(), + params.ctxParams.embeddings_path.c_str(), + params.ctxParams.stacked_id_embeddings_path.c_str(), + params.ctxParams.vae_decode_only, + params.ctxParams.vae_tiling, + false, + params.ctxParams.n_threads, + params.ctxParams.wtype, + params.ctxParams.rng_type, + params.ctxParams.schedule, + params.ctxParams.keep_clip_on_cpu, + params.ctxParams.keep_control_net_on_cpu, + params.ctxParams.keep_vae_on_cpu, + params.ctxParams.diffusion_flash_attn, + true, false, 1}; + sd_ctx = new_sd_ctx(&sd_ctx_params); + if (sd_ctx == NULL) { + printf("new_sd_ctx_t failed\n"); + std::lock_guard results_lock(results_mutex); + task_results[task_id]["status"] = "Failed"; + return; + } + } + + { + json started_task_json = json::object(); + started_task_json["status"] = "Working"; + started_task_json["data"] = json::array(); + started_task_json["step"] = 0; + started_task_json["steps"] = params.lastRequest.sample_steps; + started_task_json["eta"] = "?"; + + std::lock_guard results_lock(results_mutex); + task_results[task_id] = started_task_json; + } + + { + sd_image_t* results; + sd_slg_params_t slg = { + params.lastRequest.skip_layers.data(), + params.lastRequest.skip_layers.size(), + params.lastRequest.skip_layer_start, + params.lastRequest.skip_layer_end, + params.lastRequest.slg_scale}; + sd_guidance_params_t guidance = { + params.lastRequest.cfg_scale, + params.lastRequest.cfg_scale, + params.lastRequest.cfg_scale, + params.lastRequest.guidance, + slg}; + sd_image_t input_image = { + (uint32_t)params.lastRequest.width, + (uint32_t)params.lastRequest.height, + 3, + NULL}; + sd_image_t mask_img = input_image; + sd_img_gen_params_t gen_params = { + params.lastRequest.prompt.c_str(), + params.lastRequest.negative_prompt.c_str(), + params.lastRequest.clip_skip, + guidance, + input_image, + NULL, // ref images + 0, // ref images count + mask_img, + params.lastRequest.width, + params.lastRequest.height, + params.lastRequest.sample_method, + params.lastRequest.sample_steps, + 0.f, // eta + 1.f, // denoise strength + params.lastRequest.seed, + params.lastRequest.batch_count, + NULL, // control image ptr + 1.f, // control strength + params.lastRequest.style_ratio, + params.lastRequest.normalize_input, + params.input_id_images_path.c_str()}; + results = generate_image(sd_ctx, &gen_params); + + if (results == NULL) { + printf("generate failed\n"); + free_sd_ctx(sd_ctx); + std::lock_guard results_lock(results_mutex); + task_results[task_id]["status"] = "Failed"; + return; + } + + size_t last = params.output_path.find_last_of("."); + std::string dummy_name = last != std::string::npos ? params.output_path.substr(0, last) : params.output_path; + json images_json = json::array(); + for (int i = 0; i < params.lastRequest.batch_count; i++) { + if (results[i].data == NULL) { + continue; + } + // TODO allow disable save to disk + std::string final_image_path = i > 0 ? dummy_name + "_" + std::to_string(i + 1 + n_prompts * params.lastRequest.batch_count) + ".png" : dummy_name + ".png"; + stbi_write_png(final_image_path.c_str(), results[i].width, results[i].height, results[i].channel, + results[i].data, 0, get_image_params(params, params.lastRequest.seed + i).c_str()); + printf("save result image to '%s'\n", final_image_path.c_str()); + // Todo: return base64 encoded image via httplib::Response& res + + int len; + unsigned char* png = stbi_write_png_to_mem((const unsigned char*)results[i].data, 0, results[i].width, results[i].height, results[i].channel, &len, get_image_params(params, params.lastRequest.seed + i).c_str()); + + std::string data_str(png, png + len); + std::string encoded_img = base64_encode(data_str); + + images_json.push_back({{"width", results[i].width}, + {"height", results[i].height}, + {"channel", results[i].channel}, + {"data", encoded_img}, + {"encoding", "png"}}); + + free(results[i].data); + results[i].data = NULL; + } + free(results); + n_prompts++; + // res.set_content(images_json.dump(), "application/json"); + json end_task_json = json::object(); + end_task_json["status"] = "Done"; + end_task_json["data"] = images_json; + end_task_json["step"] = -1; + end_task_json["steps"] = 0; + end_task_json["eta"] = "?"; + std::lock_guard results_lock(results_mutex); + task_results[task_id] = end_task_json; + } + }; + // Add the task to the queue + add_task(task_id, task); + + json response = json::object(); + response["task_id"] = task_id; + res.set_content(response.dump(), "application/json"); + }); + + svr->Get("/params", [¶ms](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + json response; + json params_json = json::object(); + params_json["prompt"] = params.lastRequest.prompt; + params_json["negative_prompt"] = params.lastRequest.negative_prompt; + params_json["clip_skip"] = params.lastRequest.clip_skip; + params_json["cfg_scale"] = params.lastRequest.cfg_scale; + params_json["guidance"] = params.lastRequest.guidance; + params_json["width"] = params.lastRequest.width; + params_json["height"] = params.lastRequest.height; + params_json["sample_method"] = sd_sample_method_name(params.lastRequest.sample_method); + params_json["sample_steps"] = params.lastRequest.sample_steps; + params_json["seed"] = params.lastRequest.seed; + params_json["batch_count"] = params.lastRequest.batch_count; + params_json["normalize_input"] = params.lastRequest.normalize_input; + // params_json["input_id_images_path"] = params.input_id_images_path; + + json context_params = json::object(); + // Do not expose paths + // context_params["model_path"] = params.ctxParams.model_path; + // context_params["clip_l_path"] = params.ctxParams.clip_l_path; + // context_params["clip_g_path"] = params.ctxParams.clip_g_path; + // context_params["t5xxl_path"] = params.ctxParams.t5xxl_path; + // context_params["diffusion_model_path"] = params.ctxParams.diffusion_model_path; + // context_params["vae_path"] = params.ctxParams.vae_path; + // context_params["control_net_path"] = params.ctxParams.control_net_path; + context_params["lora_model_dir"] = params.ctxParams.lora_model_dir; + // context_params["embeddings_path"] = params.ctxParams.embeddings_path; + // context_params["stacked_id_embeddings_path"] = params.ctxParams.stacked_id_embeddings_path; + context_params["vae_decode_only"] = params.ctxParams.vae_decode_only; + context_params["vae_tiling"] = params.ctxParams.vae_tiling; + context_params["n_threads"] = params.ctxParams.n_threads; + context_params["wtype"] = params.ctxParams.wtype; + context_params["rng_type"] = params.ctxParams.rng_type; + context_params["schedule"] = sd_schedule_name(params.ctxParams.schedule); + context_params["keep_clip_on_cpu"] = params.ctxParams.keep_clip_on_cpu; + context_params["keep_control_net_on_cpu"] = params.ctxParams.keep_control_net_on_cpu; + context_params["keep_vae_on_cpu"] = params.ctxParams.keep_vae_on_cpu; + context_params["diffusion_flash_attn"] = params.ctxParams.diffusion_flash_attn; + + // response["taesd_preview"] = params.taesd_preview; + // params_json["preview_method"] = previews_str[params.lastRequest.preview_method]; + // params_json["preview_interval"] = params.lastRequest.preview_interval; + + response["generation_params"] = params_json; + response["context_params"] = context_params; + res.set_content(response.dump(), "application/json"); + }); + + svr->Get("/result", [](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + // Parse task ID from query parameters + try { + std::string task_id = req.get_param_value("task_id"); + std::lock_guard lock(results_mutex); + if (task_results.find(task_id) != task_results.end()) { + json result = task_results[task_id]; + res.set_content(result.dump(), "application/json"); + // Erase data after sending + result["data"] = json::array(); + task_results[task_id] = result; + } else { + res.set_content("Cannot find task " + task_id + " in queue", "text/plain"); + } + } catch (...) { + sd_log(sd_log_level_t::SD_LOG_WARN, "Error when fetching result"); + } + }); + + svr->Get("/sample_methods", [](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + json response; + for (int m = 0; m < SAMPLE_METHOD_COUNT; m++) { + response.push_back(sd_sample_method_name((sample_method_t)m)); + } + res.set_content(response.dump(), "application/json"); + }); + + svr->Get("/schedules", [](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + json response; + for (int s = 0; s < SCHEDULE_COUNT; s++) { + response.push_back(sd_schedule_name((schedule_t)s)); + } + res.set_content(response.dump(), "application/json"); + }); + + svr->Get("/previews", [](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + json response; + // unsupported + // for (int s = 0; s < N_PREVIEWS; s++) { + // response.push_back(previews_str[s]); + // } + res.set_content(response.dump(), "application/json"); + }); + + svr->Get("/models", [¶ms, &lora_files](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + json response; + + json models; + json diffusion_models; + json text_encoders; + json vaes; + json taes; + for (size_t i = 0; i < params.models_files.size(); i++) { + if (is_model_file(params.models_files[i])) { + models.push_back({{"id", i}, {"name", params.models_files[i]}}); + } + } + for (size_t i = 0; i < params.diffusion_models_files.size(); i++) { + if (is_model_file(params.diffusion_models_files[i])) { + diffusion_models.push_back({{"id", i}, {"name", params.diffusion_models_files[i]}}); + } + } + for (size_t i = 0; i < params.clip_files.size(); i++) { + if (is_model_file(params.clip_files[i])) { + text_encoders.push_back({{"id", i}, {"name", params.clip_files[i]}}); + } + } + for (size_t i = 0; i < params.vae_files.size(); i++) { + if (is_model_file(params.vae_files[i])) { + vaes.push_back({{"id", i}, {"name", params.vae_files[i]}}); + } + } + for (size_t i = 0; i < params.tae_files.size(); i++) { + if (is_model_file(params.tae_files[i])) { + taes.push_back({{"id", i}, {"name", params.tae_files[i]}}); + } + } + response["models"] = models; + response["diffusion_models"] = diffusion_models; + response["text_encoders"] = text_encoders; + response["vaes"] = vaes; + response["taes"] = taes; + + for (size_t i = 0; i < lora_files.size(); i++) { + std::string lora_name = lora_files[i]; + // Remove file extension + size_t pos = lora_name.find_last_of("."); + if (pos != std::string::npos) { + // Check if extension was either ".safetensors" or ".ckpt" + std::string extension = lora_name.substr(pos + 1); + lora_name = lora_name.substr(0, pos); + if (extension == "safetensors" || extension == "ckpt") { + response["loras"].push_back(lora_name); + } + } + } + + res.set_content(response.dump(), "application/json"); + }); + + svr->Get("/model", [¶ms](const httplib::Request& req, httplib::Response& res) { + using json = nlohmann::json; + json response; + if (!params.ctxParams.model_path.empty()) { + response["model"] = sd_basename(params.ctxParams.model_path); + } + if (!params.ctxParams.diffusion_model_path.empty()) { + response["diffusion_model"] = sd_basename(params.ctxParams.diffusion_model_path); + } + + if (!params.ctxParams.clip_l_path.empty()) { + response["clip_l"] = sd_basename(params.ctxParams.clip_l_path); + } + if (!params.ctxParams.clip_g_path.empty()) { + response["clip_g"] = sd_basename(params.ctxParams.clip_g_path); + } + if (!params.ctxParams.t5xxl_path.empty()) { + response["t5xxl"] = sd_basename(params.ctxParams.t5xxl_path); + } + + if (!params.ctxParams.vae_path.empty()) { + response["vae"] = sd_basename(params.ctxParams.vae_path); + } + if (!params.ctxParams.taesd_path.empty()) { + response["tae"] = sd_basename(params.ctxParams.taesd_path); + } + res.set_content(response.dump(), "application/json"); + }); + + svr->Get("/index.html", [](const httplib::Request& req, httplib::Response& res) { + try { + res.set_content(html_content, "text/html"); + } catch (const std::exception& e) { + res.set_content("Error loading page", "text/plain"); + } + }); + + // redirect base url to index + svr->Get("/", [](const httplib::Request& req, httplib::Response& res) { + res.set_redirect("/index.html"); + }); + + // bind HTTP listen port, run the HTTP server in a thread + if (!svr->bind_to_port(params.host, params.port)) { + // TODO: Error message + return; + } + std::thread t([&]() { svr->listen_after_bind(); }); + svr->wait_until_ready(); + + printf("Server listening at %s:%d\n", params.host.c_str(), params.port); + + t.join(); + + free_sd_ctx(sd_ctx); +} + +int main(int argc, const char* argv[]) { + SDParams params; + // Setup default args + parse_args(argc, argv, params); + + std::thread worker(worker_thread); + // Start the HTTP server + start_server(params); + + // Cleanup + stop_worker = true; + queue_cond.notify_one(); + worker.join(); + + return 0; +} \ No newline at end of file diff --git a/examples/server/test_client.py b/examples/server/test_client.py new file mode 100644 index 000000000..1a81e7f54 --- /dev/null +++ b/examples/server/test_client.py @@ -0,0 +1,184 @@ +import requests, json, base64 +from io import BytesIO +from PIL import Image, PngImagePlugin, ImageShow +import os +if os.name == 'nt': + from threading import Thread + +import time + + +from typing import List + + +def save_img(img: Image, path: str) -> None: + """ + Save the image to the specified path with metadata. + + Args: + img (Image): The image to be saved. + path (str): The path where the image will be saved. + + Returns: + None + """ + info = PngImagePlugin.PngInfo() + for key, value in img.info.items(): + info.add_text(key, value) + img.save(path, pnginfo=info) + +def show_img(img: Image, title = None) -> None: + """ + Display the image (with metadata) in a new window and print the path of the temporary file. + + Args: + img (Image): The image to be displayed. + title (str, optional): The title of the image window. Defaults to None. + + Returns: + None + """ + info = PngImagePlugin.PngInfo() + for key, value in img.info.items(): + info.add_text(key, value) + tmp = img._dump(None, format=img.format, pnginfo=info) + print(f"Image path: {tmp}\n") + for viewer in ImageShow._viewers: + if viewer.show_file(tmp,title=title): + return + +_protocol = "http" +_server = "localhost" +_port = 8080 +_endpoint = "txt2img" +url="" + +def update_url(protocol=None, server=None, port=None, endpoint=None) -> str: + """ + Update the global URL variable with the provided protocol, server, port, and endpoint. + + This function takes optional arguments for protocol, server, port, and endpoint. + If any of these arguments are provided, the corresponding global variable is updated with the new value. + The function then constructs the URL using the updated global variables and returns it. + + Args: + protocol (str, optional): The protocol to be used in the URL. Defaults to None. + server (str, optional): The server address to be used in the URL. Defaults to None. + port (int, optional): The port number to be used in the URL. Defaults to None. + endpoint (str, optional): The endpoint to be used in the URL. Defaults to None. + + Returns: + str: The updated URL. + """ + global _protocol, _server, _port, _endpoint, url + if protocol: + _protocol = protocol + if server: + _server = server + if port: + _port = port + if endpoint: + _endpoint = endpoint + url = f"{_protocol}://{_server}:{_port}/{_endpoint}" + return url + +# set default url value +update_url() + +def poll_result(id: str, show_previews = False): + global _protocol + global _server + global _port + + res = {'status':""} + while res['status'] != "Done": + res = requests.get(f"{_protocol}://{_server}:{_port}/result", params={'task_id':id}, timeout=.25).json() + if(show_previews and res['status'] == "Working" and len(res['data'])>0): + showImages(getImages(json.dumps(res['data']))) + + return json.dumps(res['data']) + +def sendRequest(payload: str) -> str: + """ + Send a POST request to the API endpoint with the provided payload. + + This function takes a payload as input and sends a POST request to the API endpoint specified by the global URL variable. + The function then returns the text content of the response. + + Args: + payload (str): The payload to be sent in the POST request. + + Returns: + str: The text content of the response from the POST request. + """ + global url + return requests.post(url, payload).json()['task_id'] + +def getImages(response: str) -> List[Image.Image]: + """ + Convert base64 encoded image data from the API response into a list of Image objects. + + This function takes the text response from the API as input and parses it as JSON. + It then iterates over each image data in the JSON response, decodes the base64 encoded image data, + and uses the BytesIO class to convert it into a PIL Image object. + The function returns a list of these Image objects. + + Args: + response (str): The text response from the API containing base64 encoded image data. + + Returns: + List[Image.Image]: A list of PIL Image objects decoded from the base64 encoded image data in the API response. + """ + return [Image.open(BytesIO(base64.b64decode(img["data"]))) for img in json.loads(response)] + +def showImages(imgs: List[Image.Image]) -> None: + """ + Display a list of images in separate threads. + + This function takes a list of PIL Image objects as input and creates a new thread for each image. + Each thread calls the show_img function to display the image in a new window and print the path of the temporary file. + The function does not return any value. + + Args: + imgs (List[Image.Image]): A list of PIL Image objects to be displayed. + + Returns: + None + """ + for (i,img) in enumerate(imgs): + if os.name == 'nt': + t = Thread(target=show_img, args=(img, f"IMG {i}")) + t.daemon = True + t.start() + else: + show_img(img, f"IMG {i}") + +def saveImages(imgs: List[Image.Image], path: str) -> None: + """ + Save a list of images to the specified path with metadata. + + This function takes a list of PIL Image objects and a path as input. + For each image, it calls the save_img function to save the image to a file + with the name "{path}{i}.png", where i is the index of the image in the list. + The function does not return any value. + + Args: + imgs (List[Image.Image]): A list of PIL Image objects to be saved. + path (str): The path where the images will be saved. + + Returns: + None + """ + if path.endswith(".png"): + path = path[:-4] + for (i, img) in enumerate(imgs): + save_img(img, f"{path}{i}.png") + + +def _print_usage(): + print("""Example usage (images will be displayed and saved to a temporary file): +update_url(server="127.0.0.1", port=8080) +showImages(getImages(poll_result(sendRequest(json.dumps({'seed': -1, 'batch_count':4, 'sample_steps':24, 'width': 512, 'height':768, 'negative_prompt': "Bad quality", 'prompt': "A beautiful image"})))))""") + +if __name__ == "__main__": + _print_usage() diff --git a/face_detect.py b/face_detect.py new file mode 100644 index 000000000..7131af31f --- /dev/null +++ b/face_detect.py @@ -0,0 +1,88 @@ +import os +import sys + +import numpy as np +import torch +from diffusers.utils import load_image +# pip install insightface==0.7.3 +from insightface.app import FaceAnalysis +from insightface.data import get_image as ins_get_image +from safetensors.torch import save_file + +### +# https://github.com/cubiq/ComfyUI_IPAdapter_plus/issues/165#issue-2055829543 +### +class FaceAnalysis2(FaceAnalysis): + # NOTE: allows setting det_size for each detection call. + # the model allows it but the wrapping code from insightface + # doesn't show it, and people end up loading duplicate models + # for different sizes where there is absolutely no need to + def get(self, img, max_num=0, det_size=(640, 640)): + if det_size is not None: + self.det_model.input_size = det_size + + return super().get(img, max_num) + +def analyze_faces(face_analysis: FaceAnalysis, img_data: np.ndarray, det_size=(640, 640)): + # NOTE: try detect faces, if no faces detected, lower det_size until it does + detection_sizes = [None] + [(size, size) for size in range(640, 256, -64)] + [(256, 256)] + + for size in detection_sizes: + faces = face_analysis.get(img_data, det_size=size) + if len(faces) > 0: + return faces + + return [] + +if __name__ == "__main__": + #face_detector = FaceAnalysis2(providers=['CUDAExecutionProvider'], allowed_modules=['detection', 'recognition']) + face_detector = FaceAnalysis2(providers=['CPUExecutionProvider'], allowed_modules=['detection', 'recognition']) + face_detector.prepare(ctx_id=0, det_size=(640, 640)) + #input_folder_name = './scarletthead_woman' + input_folder_name = sys.argv[1] + image_basename_list = os.listdir(input_folder_name) + image_path_list = sorted([os.path.join(input_folder_name, basename) for basename in image_basename_list]) + + input_id_images = [] + for image_path in image_path_list: + input_id_images.append(load_image(image_path)) + + id_embed_list = [] + + for img in input_id_images: + img = np.array(img) + img = img[:, :, ::-1] + faces = analyze_faces(face_detector, img) + if len(faces) > 0: + id_embed_list.append(torch.from_numpy((faces[0]['embedding']))) + + if len(id_embed_list) == 0: + raise ValueError(f"No face detected in input image pool") + + id_embeds = torch.stack(id_embed_list) + + # for r in id_embeds: + # print(r) + # #torch.save(id_embeds, input_folder_name+'/id_embeds.pt'); + # weights = dict() + # weights["id_embeds"] = id_embeds + # save_file(weights, input_folder_name+'/id_embeds.safetensors') + + binary_data = id_embeds.numpy().tobytes() + two = 4 + zero = 0 + one = 1 + tensor_name = "id_embeds" +# Write binary data to a file + with open(input_folder_name+'/id_embeds.bin', "wb") as f: + f.write(two.to_bytes(4, byteorder='little')) + f.write((len(tensor_name)).to_bytes(4, byteorder='little')) + f.write(zero.to_bytes(4, byteorder='little')) + f.write((id_embeds.shape[1]).to_bytes(4, byteorder='little')) + f.write((id_embeds.shape[0]).to_bytes(4, byteorder='little')) + f.write(one.to_bytes(4, byteorder='little')) + f.write(one.to_bytes(4, byteorder='little')) + f.write(tensor_name.encode('ascii')) + f.write(binary_data) + + \ No newline at end of file diff --git a/flux.hpp b/flux.hpp index 73bc345a7..11045918f 100644 --- a/flux.hpp +++ b/flux.hpp @@ -35,8 +35,9 @@ namespace Flux { int64_t hidden_size; float eps; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["scale"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "scale") != tensor_types.end()) ? tensor_types[prefix + "scale"] : GGML_TYPE_F32; + params["scale"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } public: @@ -115,25 +116,29 @@ namespace Flux { struct ggml_tensor* q, struct ggml_tensor* k, struct ggml_tensor* v, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask, + bool flash_attn) { // q,k,v: [N, L, n_head, d_head] // pe: [L, d_head/2, 2, 2] // return: [N, L, n_head*d_head] q = apply_rope(ctx, q, pe); // [N*n_head, L, d_head] k = apply_rope(ctx, k, pe); // [N*n_head, L, d_head] - auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], NULL, false, true); // [N, L, n_head*d_head] + auto x = ggml_nn_attention_ext(ctx, q, k, v, v->ne[1], mask, false, true, flash_attn); // [N, L, n_head*d_head] return x; } struct SelfAttention : public GGMLBlock { public: int64_t num_heads; + bool flash_attn; public: SelfAttention(int64_t dim, int64_t num_heads = 8, - bool qkv_bias = false) + bool qkv_bias = false, + bool flash_attn = false) : num_heads(num_heads) { int64_t head_dim = dim / num_heads; blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); @@ -163,13 +168,13 @@ namespace Flux { return x; } - struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe) { + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* pe, struct ggml_tensor* mask) { // x: [N, n_token, dim] // pe: [n_token, d_head/2, 2, 2] // return [N, n_token, dim] - auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] - x = attention(ctx, qkv[0], qkv[1], qkv[2], pe); // [N, n_token, dim] - x = post_attention(ctx, x); // [N, n_token, dim] + auto qkv = pre_attention(ctx, x); // q,k,v: [N, n_token, n_head, d_head] + x = attention(ctx, qkv[0], qkv[1], qkv[2], pe, mask, flash_attn); // [N, n_token, dim] + x = post_attention(ctx, x); // [N, n_token, dim] return x; } }; @@ -181,6 +186,13 @@ namespace Flux { ModulationOut(ggml_tensor* shift = NULL, ggml_tensor* scale = NULL, ggml_tensor* gate = NULL) : shift(shift), scale(scale), gate(gate) {} + + ModulationOut(struct ggml_context* ctx, ggml_tensor* vec, int64_t offset) { + int64_t stride = vec->nb[1] * vec->ne[1]; + shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + gate = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 2)); // [N, dim] + } }; struct Modulation : public GGMLBlock { @@ -206,19 +218,12 @@ namespace Flux { auto m = ggml_reshape_3d(ctx, out, vec->ne[0], multiplier, vec->ne[1]); // [N, multiplier, dim] m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [multiplier, N, dim] - int64_t offset = m->nb[1] * m->ne[1]; - auto shift_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, dim] - auto scale_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, dim] - auto gate_0 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, dim] - + ModulationOut m_0 = ModulationOut(ctx, m, 0); if (is_double) { - auto shift_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, dim] - auto scale_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, dim] - auto gate_1 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, dim] - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut(shift_1, scale_1, gate_1)}; + return {m_0, ModulationOut(ctx, m, 3)}; } - return {ModulationOut(shift_0, scale_0, gate_0), ModulationOut()}; + return {m_0, ModulationOut()}; } }; @@ -237,24 +242,36 @@ namespace Flux { } struct DoubleStreamBlock : public GGMLBlock { + bool flash_attn; + bool prune_mod; + int idx = 0; + public: DoubleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio, - bool qkv_bias = false) { + int idx = 0, + bool qkv_bias = false, + bool flash_attn = false, + bool prune_mod = false) + : idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t mlp_hidden_dim = hidden_size * mlp_ratio; - blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); - blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + if (!prune_mod) { + blocks["img_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } + blocks["img_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); + blocks["img_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["img_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["img_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); // img_mlp.1 is nn.GELU(approximate="tanh") blocks["img_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); - blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + if (!prune_mod) { + blocks["txt_mod"] = std::shared_ptr(new Modulation(hidden_size, true)); + } blocks["txt_norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); - blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias)); + blocks["txt_attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, flash_attn)); blocks["txt_norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); blocks["txt_mlp.0"] = std::shared_ptr(new Linear(hidden_size, mlp_hidden_dim)); @@ -262,17 +279,34 @@ namespace Flux { blocks["txt_mlp.2"] = std::shared_ptr(new Linear(mlp_hidden_dim, hidden_size)); } + std::vector get_distil_img_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? + const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + + std::vector get_distil_txt_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + // TODO: not hardcoded? + const int single_blocks_count = 38; + const int double_blocks_count = 19; + + int64_t offset = 6 * idx + 6 * double_blocks_count + 3 * single_blocks_count; + return {ModulationOut(ctx, vec, offset), ModulationOut(ctx, vec, offset + 3)}; + } + std::pair forward(struct ggml_context* ctx, struct ggml_tensor* img, struct ggml_tensor* txt, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // img: [N, n_img_token, hidden_size] // txt: [N, n_txt_token, hidden_size] // pe: [n_img_token + n_txt_token, d_head/2, 2, 2] // return: ([N, n_img_token, hidden_size], [N, n_txt_token, hidden_size]) - - auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); auto img_norm1 = std::dynamic_pointer_cast(blocks["img_norm1"]); auto img_attn = std::dynamic_pointer_cast(blocks["img_attn"]); @@ -280,7 +314,6 @@ namespace Flux { auto img_mlp_0 = std::dynamic_pointer_cast(blocks["img_mlp.0"]); auto img_mlp_2 = std::dynamic_pointer_cast(blocks["img_mlp.2"]); - auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); auto txt_norm1 = std::dynamic_pointer_cast(blocks["txt_norm1"]); auto txt_attn = std::dynamic_pointer_cast(blocks["txt_attn"]); @@ -288,10 +321,22 @@ namespace Flux { auto txt_mlp_0 = std::dynamic_pointer_cast(blocks["txt_mlp.0"]); auto txt_mlp_2 = std::dynamic_pointer_cast(blocks["txt_mlp.2"]); - auto img_mods = img_mod->forward(ctx, vec); + std::vector img_mods; + if (prune_mod) { + img_mods = get_distil_img_mod(ctx, vec); + } else { + auto img_mod = std::dynamic_pointer_cast(blocks["img_mod"]); + img_mods = img_mod->forward(ctx, vec); + } ModulationOut img_mod1 = img_mods[0]; ModulationOut img_mod2 = img_mods[1]; - auto txt_mods = txt_mod->forward(ctx, vec); + std::vector txt_mods; + if (prune_mod) { + txt_mods = get_distil_txt_mod(ctx, vec); + } else { + auto txt_mod = std::dynamic_pointer_cast(blocks["txt_mod"]); + txt_mods = txt_mod->forward(ctx, vec); + } ModulationOut txt_mod1 = txt_mods[0]; ModulationOut txt_mod2 = txt_mods[1]; @@ -316,7 +361,7 @@ namespace Flux { auto k = ggml_concat(ctx, txt_k, img_k, 2); // [N, n_txt_token + n_img_token, n_head, d_head] auto v = ggml_concat(ctx, txt_v, img_v, 2); // [N, n_txt_token + n_img_token, n_head, d_head] - auto attn = attention(ctx, q, k, v, pe); // [N, n_txt_token + n_img_token, n_head*d_head] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_txt_token + n_img_token, n_head*d_head] attn = ggml_cont(ctx, ggml_permute(ctx, attn, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] auto txt_attn_out = ggml_view_3d(ctx, attn, @@ -364,13 +409,19 @@ namespace Flux { int64_t num_heads; int64_t hidden_size; int64_t mlp_hidden_dim; + bool flash_attn; + bool prune_mod; + int idx = 0; public: SingleStreamBlock(int64_t hidden_size, int64_t num_heads, float mlp_ratio = 4.0f, - float qk_scale = 0.f) - : hidden_size(hidden_size), num_heads(num_heads) { + int idx = 0, + float qk_scale = 0.f, + bool flash_attn = false, + bool prune_mod = false) + : hidden_size(hidden_size), num_heads(num_heads), idx(idx), flash_attn(flash_attn), prune_mod(prune_mod) { int64_t head_dim = hidden_size / num_heads; float scale = qk_scale; if (scale <= 0.f) { @@ -383,26 +434,37 @@ namespace Flux { blocks["norm"] = std::shared_ptr(new QKNorm(head_dim)); blocks["pre_norm"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-6f, false)); // mlp_act is nn.GELU(approximate="tanh") - blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + if (!prune_mod) { + blocks["modulation"] = std::shared_ptr(new Modulation(hidden_size, false)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = 3 * idx; + return ModulationOut(ctx, vec, offset); } struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* vec, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mask = NULL) { // x: [N, n_token, hidden_size] // pe: [n_token, d_head/2, 2, 2] // return: [N, n_token, hidden_size] - auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); - auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); - auto norm = std::dynamic_pointer_cast(blocks["norm"]); - auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); - auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); - - auto mods = modulation->forward(ctx, vec); - ModulationOut mod = mods[0]; - + auto linear1 = std::dynamic_pointer_cast(blocks["linear1"]); + auto linear2 = std::dynamic_pointer_cast(blocks["linear2"]); + auto norm = std::dynamic_pointer_cast(blocks["norm"]); + auto pre_norm = std::dynamic_pointer_cast(blocks["pre_norm"]); + ModulationOut mod; + if (prune_mod) { + mod = get_distil_mod(ctx, vec); + } else { + auto modulation = std::dynamic_pointer_cast(blocks["modulation"]); + + mod = modulation->forward(ctx, vec)[0]; + } auto x_mod = Flux::modulate(ctx, pre_norm->forward(ctx, x), mod.shift, mod.scale); auto qkv_mlp = linear1->forward(ctx, x_mod); // [N, n_token, hidden_size * 3 + mlp_hidden_dim] qkv_mlp = ggml_cont(ctx, ggml_permute(ctx, qkv_mlp, 2, 0, 1, 3)); // [hidden_size * 3 + mlp_hidden_dim, N, n_token] @@ -433,7 +495,7 @@ namespace Flux { auto v = ggml_reshape_4d(ctx, qkv_vec[2], head_dim, num_heads, qkv_vec[2]->ne[1], qkv_vec[2]->ne[2]); // [N, n_token, n_head, d_head] q = norm->query_norm(ctx, q); k = norm->key_norm(ctx, k); - auto attn = attention(ctx, q, k, v, pe); // [N, n_token, hidden_size] + auto attn = attention(ctx, q, k, v, pe, mask, flash_attn); // [N, n_token, hidden_size] auto attn_mlp = ggml_concat(ctx, attn, ggml_gelu_inplace(ctx, mlp), 0); // [N, n_token, hidden_size + mlp_hidden_dim] auto output = linear2->forward(ctx, attn_mlp); // [N, n_token, hidden_size] @@ -444,13 +506,28 @@ namespace Flux { }; struct LastLayer : public GGMLBlock { + bool prune_mod; + public: LastLayer(int64_t hidden_size, int64_t patch_size, - int64_t out_channels) { - blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); - blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + int64_t out_channels, + bool prune_mod = false) + : prune_mod(prune_mod) { + blocks["norm_final"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); + blocks["linear"] = std::shared_ptr(new Linear(hidden_size, patch_size * patch_size * out_channels)); + if (!prune_mod) { + blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, 2 * hidden_size)); + } + } + + ModulationOut get_distil_mod(struct ggml_context* ctx, struct ggml_tensor* vec) { + int64_t offset = vec->ne[2] - 2; + int64_t stride = vec->nb[1] * vec->ne[1]; + auto shift = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 0)); // [N, dim] + auto scale = ggml_view_2d(ctx, vec, vec->ne[0], vec->ne[1], vec->nb[1], stride * (offset + 1)); // [N, dim] + // No gate + return ModulationOut(shift, scale, NULL); } struct ggml_tensor* forward(struct ggml_context* ctx, @@ -459,17 +536,24 @@ namespace Flux { // x: [N, n_token, hidden_size] // c: [N, hidden_size] // return: [N, n_token, patch_size * patch_size * out_channels] - auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); - auto linear = std::dynamic_pointer_cast(blocks["linear"]); - auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); - - auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] - m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] - m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] - - int64_t offset = m->nb[1] * m->ne[1]; - auto shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] - auto scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto norm_final = std::dynamic_pointer_cast(blocks["norm_final"]); + auto linear = std::dynamic_pointer_cast(blocks["linear"]); + struct ggml_tensor *shift, *scale; + if (prune_mod) { + auto mod = get_distil_mod(ctx, c); + shift = mod.shift; + scale = mod.scale; + } else { + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, 2 * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], 2, c->ne[1]); // [N, 2, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [2, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + shift = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + scale = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + } x = Flux::modulate(ctx, norm_final->forward(ctx, x), shift, scale); x = linear->forward(ctx, x); @@ -478,8 +562,37 @@ namespace Flux { } }; + struct ChromaApproximator : public GGMLBlock { + int64_t inner_size = 5120; + int64_t n_layers = 5; + ChromaApproximator(int64_t in_channels = 64, int64_t hidden_size = 3072) { + blocks["in_proj"] = std::shared_ptr(new Linear(in_channels, inner_size, true)); + for (int i = 0; i < n_layers; i++) { + blocks["norms." + std::to_string(i)] = std::shared_ptr(new RMSNorm(inner_size)); + blocks["layers." + std::to_string(i)] = std::shared_ptr(new MLPEmbedder(inner_size, inner_size)); + } + blocks["out_proj"] = std::shared_ptr(new Linear(inner_size, hidden_size, true)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + auto in_proj = std::dynamic_pointer_cast(blocks["in_proj"]); + auto out_proj = std::dynamic_pointer_cast(blocks["out_proj"]); + + x = in_proj->forward(ctx, x); + for (int i = 0; i < n_layers; i++) { + auto norm = std::dynamic_pointer_cast(blocks["norms." + std::to_string(i)]); + auto embed = std::dynamic_pointer_cast(blocks["layers." + std::to_string(i)]); + x = ggml_add_inplace(ctx, x, embed->forward(ctx, norm->forward(ctx, x))); + } + x = out_proj->forward(ctx, x); + + return x; + } + }; + struct FluxParams { int64_t in_channels = 64; + int64_t out_channels = 64; int64_t vec_in_dim = 768; int64_t context_in_dim = 4096; int64_t hidden_size = 3072; @@ -492,6 +605,8 @@ namespace Flux { int theta = 10000; bool qkv_bias = true; bool guidance_embed = true; + bool flash_attn = true; + bool is_chroma = false; }; struct Flux : public GGMLBlock { @@ -558,17 +673,22 @@ namespace Flux { } // Generate IDs for image patches and text - std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len) { + std::vector> gen_txt_ids(int bs, int context_len) { + return std::vector>(bs * context_len, std::vector(3, 0.0)); + } + + std::vector> gen_img_ids(int h, int w, int patch_size, int bs, int index = 0, int h_offset = 0, int w_offset = 0) { int h_len = (h + (patch_size / 2)) / patch_size; int w_len = (w + (patch_size / 2)) / patch_size; std::vector> img_ids(h_len * w_len, std::vector(3, 0.0)); - std::vector row_ids = linspace(0, h_len - 1, h_len); - std::vector col_ids = linspace(0, w_len - 1, w_len); + std::vector row_ids = linspace(h_offset, h_len - 1 + h_offset, h_len); + std::vector col_ids = linspace(w_offset, w_len - 1 + w_offset, w_len); for (int i = 0; i < h_len; ++i) { for (int j = 0; j < w_len; ++j) { + img_ids[i * w_len + j][0] = index; img_ids[i * w_len + j][1] = row_ids[i]; img_ids[i * w_len + j][2] = col_ids[j]; } @@ -580,24 +700,54 @@ namespace Flux { img_ids_repeated[i * img_ids.size() + j] = img_ids[j]; } } + return img_ids_repeated; + } - std::vector> txt_ids(bs * context_len, std::vector(3, 0.0)); - std::vector> ids(bs * (context_len + img_ids.size()), std::vector(3)); + std::vector> concat_ids(const std::vector>& a, + const std::vector>& b, + int bs) { + size_t a_len = a.size() / bs; + size_t b_len = b.size() / bs; + std::vector> ids(a.size() + b.size(), std::vector(3)); for (int i = 0; i < bs; ++i) { - for (int j = 0; j < context_len; ++j) { - ids[i * (context_len + img_ids.size()) + j] = txt_ids[j]; + for (int j = 0; j < a_len; ++j) { + ids[i * (a_len + b_len) + j] = a[i * a_len + j]; } - for (int j = 0; j < img_ids.size(); ++j) { - ids[i * (context_len + img_ids.size()) + context_len + j] = img_ids_repeated[i * img_ids.size() + j]; + for (int j = 0; j < b_len; ++j) { + ids[i * (a_len + b_len) + a_len + j] = b[i * b_len + j]; } } + return ids; + } + + std::vector> gen_ids(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents) { + auto txt_ids = gen_txt_ids(bs, context_len); + auto img_ids = gen_img_ids(h, w, patch_size, bs); + + auto ids = concat_ids(txt_ids, img_ids, bs); + uint64_t curr_h_offset = 0; + uint64_t curr_w_offset = 0; + for (ggml_tensor* ref : ref_latents) { + uint64_t h_offset = 0; + uint64_t w_offset = 0; + if (ref->ne[1] + curr_h_offset > ref->ne[0] + curr_w_offset) { + w_offset = curr_w_offset; + } else { + h_offset = curr_h_offset; + } + + auto ref_ids = gen_img_ids(ref->ne[1], ref->ne[0], patch_size, bs, 1, h_offset, w_offset); + ids = concat_ids(ids, ref_ids, bs); + curr_h_offset = std::max(curr_h_offset, ref->ne[1] + h_offset); + curr_w_offset = std::max(curr_w_offset, ref->ne[0] + w_offset); + } return ids; } // Generate positional embeddings - std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, int theta, const std::vector& axes_dim) { - std::vector> ids = gen_ids(h, w, patch_size, bs, context_len); + std::vector gen_pe(int h, int w, int patch_size, int bs, int context_len, std::vector ref_latents, int theta, const std::vector& axes_dim) { + std::vector> ids = gen_ids(h, w, patch_size, bs, context_len, ref_latents); std::vector> trans_ids = transpose(ids); size_t pos_len = ids.size(); int num_axes = axes_dim.size(); @@ -631,14 +781,17 @@ namespace Flux { Flux() {} Flux(FluxParams params) : params(params) { - int64_t out_channels = params.in_channels; - int64_t pe_dim = params.hidden_size / params.num_heads; - - blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); - blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); - blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); - if (params.guidance_embed) { - blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + int64_t pe_dim = params.hidden_size / params.num_heads; + + blocks["img_in"] = std::shared_ptr(new Linear(params.in_channels, params.hidden_size, true)); + if (params.is_chroma) { + blocks["distilled_guidance_layer"] = std::shared_ptr(new ChromaApproximator(params.in_channels, params.hidden_size)); + } else { + blocks["time_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + blocks["vector_in"] = std::shared_ptr(new MLPEmbedder(params.vec_in_dim, params.hidden_size)); + if (params.guidance_embed) { + blocks["guidance_in"] = std::shared_ptr(new MLPEmbedder(256, params.hidden_size)); + } } blocks["txt_in"] = std::shared_ptr(new Linear(params.context_in_dim, params.hidden_size, true)); @@ -646,16 +799,23 @@ namespace Flux { blocks["double_blocks." + std::to_string(i)] = std::shared_ptr(new DoubleStreamBlock(params.hidden_size, params.num_heads, params.mlp_ratio, - params.qkv_bias)); + i, + params.qkv_bias, + params.flash_attn, + params.is_chroma)); } for (int i = 0; i < params.depth_single_blocks; i++) { blocks["single_blocks." + std::to_string(i)] = std::shared_ptr(new SingleStreamBlock(params.hidden_size, params.num_heads, - params.mlp_ratio)); + params.mlp_ratio, + i, + 0.f, + params.flash_attn, + params.is_chroma)); } - blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, out_channels)); + blocks["final_layer"] = std::shared_ptr(new LastLayer(params.hidden_size, 1, params.out_channels, params.is_chroma)); } struct ggml_tensor* patchify(struct ggml_context* ctx, @@ -711,40 +871,78 @@ namespace Flux { struct ggml_tensor* timesteps, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mod_index_arange = NULL, + std::vector skip_layers = {}) { auto img_in = std::dynamic_pointer_cast(blocks["img_in"]); - auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); - auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); auto txt_in = std::dynamic_pointer_cast(blocks["txt_in"]); auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); - img = img_in->forward(ctx, img); - auto vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + img = img_in->forward(ctx, img); + struct ggml_tensor* vec; + struct ggml_tensor* txt_img_mask = NULL; + if (params.is_chroma) { + int64_t mod_index_length = 344; + auto approx = std::dynamic_pointer_cast(blocks["distilled_guidance_layer"]); + auto distill_timestep = ggml_nn_timestep_embedding(ctx, timesteps, 16, 10000, 1000.f); + auto distill_guidance = ggml_nn_timestep_embedding(ctx, guidance, 16, 10000, 1000.f); + + // auto mod_index_arange = ggml_arange(ctx, 0, (float)mod_index_length, 1); + // ggml_arange tot working on a lot of backends, precomputing it on CPU instead + GGML_ASSERT(arange != NULL); + auto modulation_index = ggml_nn_timestep_embedding(ctx, mod_index_arange, 32, 10000, 1000.f); // [1, 344, 32] + + // Batch broadcast (will it ever be useful) + modulation_index = ggml_repeat(ctx, modulation_index, ggml_new_tensor_3d(ctx, GGML_TYPE_F32, modulation_index->ne[0], modulation_index->ne[1], img->ne[2])); // [N, 344, 32] + + auto timestep_guidance = ggml_concat(ctx, distill_timestep, distill_guidance, 0); // [N, 1, 32] + timestep_guidance = ggml_repeat(ctx, timestep_guidance, modulation_index); // [N, 344, 32] + + vec = ggml_concat(ctx, timestep_guidance, modulation_index, 0); // [N, 344, 64] + // Permute for consistency with non-distilled modulation implementation + vec = ggml_cont(ctx, ggml_permute(ctx, vec, 0, 2, 1, 3)); // [344, N, 64] + vec = approx->forward(ctx, vec); // [344, N, hidden_size] + + if (y != NULL) { + txt_img_mask = ggml_pad(ctx, y, img->ne[1], 0, 0, 0); + } + } else { + auto time_in = std::dynamic_pointer_cast(blocks["time_in"]); + auto vector_in = std::dynamic_pointer_cast(blocks["vector_in"]); + vec = time_in->forward(ctx, ggml_nn_timestep_embedding(ctx, timesteps, 256, 10000, 1000.f)); + if (params.guidance_embed) { + GGML_ASSERT(guidance != NULL); + auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); + // bf16 and fp16 result is different + auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); + vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + } - if (params.guidance_embed) { - GGML_ASSERT(guidance != NULL); - auto guidance_in = std::dynamic_pointer_cast(blocks["guidance_in"]); - // bf16 and fp16 result is different - auto g_in = ggml_nn_timestep_embedding(ctx, guidance, 256, 10000, 1000.f); - vec = ggml_add(ctx, vec, guidance_in->forward(ctx, g_in)); + vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); } - vec = ggml_add(ctx, vec, vector_in->forward(ctx, y)); txt = txt_in->forward(ctx, txt); for (int i = 0; i < params.depth; i++) { + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { + continue; + } + auto block = std::dynamic_pointer_cast(blocks["double_blocks." + std::to_string(i)]); - auto img_txt = block->forward(ctx, img, txt, vec, pe); + auto img_txt = block->forward(ctx, img, txt, vec, pe, txt_img_mask); img = img_txt.first; // [N, n_img_token, hidden_size] txt = img_txt.second; // [N, n_txt_token, hidden_size] } auto txt_img = ggml_concat(ctx, txt, img, 1); // [N, n_txt_token + n_img_token, hidden_size] for (int i = 0; i < params.depth_single_blocks; i++) { + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i + params.depth) != skip_layers.end()) { + continue; + } auto block = std::dynamic_pointer_cast(blocks["single_blocks." + std::to_string(i)]); - txt_img = block->forward(ctx, txt_img, vec, pe); + txt_img = block->forward(ctx, txt_img, vec, pe, txt_img_mask); } txt_img = ggml_cont(ctx, ggml_permute(ctx, txt_img, 0, 2, 1, 3)); // [n_txt_token + n_img_token, N, hidden_size] @@ -759,7 +957,20 @@ namespace Flux { img = ggml_cont(ctx, ggml_permute(ctx, img, 0, 2, 1, 3)); // [N, n_img_token, hidden_size] img = final_layer->forward(ctx, img, vec); // (N, T, patch_size ** 2 * out_channels) + return img; + } + struct ggml_tensor* process_img(struct ggml_context* ctx, + struct ggml_tensor* x) { + int64_t W = x->ne[0]; + int64_t H = x->ne[1]; + int64_t patch_size = 2; + int pad_h = (patch_size - H % patch_size) % patch_size; + int pad_w = (patch_size - W % patch_size) % patch_size; + x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] + + // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) + auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] return img; } @@ -767,13 +978,18 @@ namespace Flux { struct ggml_tensor* x, struct ggml_tensor* timestep, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor* pe) { + struct ggml_tensor* pe, + struct ggml_tensor* mod_index_arange = NULL, + std::vector ref_latents = {}, + std::vector skip_layers = {}) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // timestep: (N,) tensor of diffusion timesteps // context: (N, L, D) + // c_concat: NULL, or for (N,C+M, H, W) for Fill // y: (N, adm_in_channels) tensor of class labels // guidance: (N,) // pe: (L, d_head/2, 2, 2) @@ -783,15 +999,37 @@ namespace Flux { int64_t W = x->ne[0]; int64_t H = x->ne[1]; + int64_t C = x->ne[2]; int64_t patch_size = 2; int pad_h = (patch_size - H % patch_size) % patch_size; int pad_w = (patch_size - W % patch_size) % patch_size; - x = ggml_pad(ctx, x, pad_w, pad_h, 0, 0); // [N, C, H + pad_h, W + pad_w] - // img = rearrange(x, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=patch_size, pw=patch_size) - auto img = patchify(ctx, x, patch_size); // [N, h*w, C * patch_size * patch_size] + auto img = process_img(ctx, x); + uint64_t img_tokens = img->ne[1]; - auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe); // [N, h*w, C * patch_size * patch_size] + if (c_concat != NULL) { + ggml_tensor* masked = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], C, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], 0); + ggml_tensor* mask = ggml_view_4d(ctx, c_concat, c_concat->ne[0], c_concat->ne[1], 8 * 8, 1, c_concat->nb[1], c_concat->nb[2], c_concat->nb[3], c_concat->nb[2] * C); + + masked = process_img(ctx, masked); + mask = process_img(ctx, mask); + + img = ggml_concat(ctx, img, ggml_concat(ctx, masked, mask, 0), 0); + } + + if (ref_latents.size() > 0) { + for (ggml_tensor* ref : ref_latents) { + ref = process_img(ctx, ref); + img = ggml_concat(ctx, img, ref, 1); + } + } + + auto out = forward_orig(ctx, img, context, timestep, y, guidance, pe, mod_index_arange, skip_layers); // [N, num_tokens, C * patch_size * patch_size] + if (out->ne[1] > img_tokens) { + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [num_tokens, N, C * patch_size * patch_size] + out = ggml_view_3d(ctx, out, out->ne[0], out->ne[1], img_tokens, out->nb[1], out->nb[2], 0); + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); // [N, h*w, C * patch_size * patch_size] + } // rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2) out = unpatchify(ctx, out, (H + pad_h) / patch_size, (W + pad_w) / patch_size, patch_size); // [N, C, H + pad_h, W + pad_w] @@ -801,20 +1039,69 @@ namespace Flux { }; struct FluxRunner : public GGMLRunner { + static std::map empty_tensor_types; + public: FluxParams flux_params; Flux flux; - std::vector pe_vec; // for cache + std::vector pe_vec; + std::vector mod_index_arange_vec; // for cache + SDVersion version; + bool use_mask = false; FluxRunner(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_FLUX_DEV) - : GGMLRunner(backend, wtype) { - if (version == VERSION_FLUX_SCHNELL) { - flux_params.guidance_embed = false; + std::map& tensor_types = empty_tensor_types, + const std::string prefix = "", + SDVersion version = VERSION_FLUX, + bool flash_attn = false, + bool use_mask = false) + : GGMLRunner(backend), use_mask(use_mask) { + flux_params.flash_attn = flash_attn; + flux_params.guidance_embed = false; + flux_params.depth = 0; + flux_params.depth_single_blocks = 0; + if (version == VERSION_FLUX_FILL) { + flux_params.in_channels = 384; + } + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find("model.diffusion_model.") == std::string::npos) + continue; + if (tensor_name.find("guidance_in.in_layer.weight") != std::string::npos) { + // not schnell + flux_params.guidance_embed = true; + } + if (tensor_name.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + // Chroma + flux_params.is_chroma = true; + } + size_t db = tensor_name.find("double_blocks."); + if (db != std::string::npos) { + tensor_name = tensor_name.substr(db); // remove prefix + int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str()); + if (block_depth + 1 > flux_params.depth) { + flux_params.depth = block_depth + 1; + } + } + size_t sb = tensor_name.find("single_blocks."); + if (sb != std::string::npos) { + tensor_name = tensor_name.substr(sb); // remove prefix + int block_depth = atoi(tensor_name.substr(14, tensor_name.find(".", 14)).c_str()); + if (block_depth + 1 > flux_params.depth_single_blocks) { + flux_params.depth_single_blocks = block_depth + 1; + } + } + } + + LOG_INFO("Flux blocks: %d double, %d single", flux_params.depth, flux_params.depth_single_blocks); + if (flux_params.is_chroma) { + LOG_INFO("Using pruned modulation (Chroma)"); + } else if (!flux_params.guidance_embed) { + LOG_INFO("Flux guidance is disabled (Schnell mode)"); } + flux = Flux(flux_params); - flux.init(params_ctx, wtype); + flux.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -828,20 +1115,44 @@ namespace Flux { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, - struct ggml_tensor* guidance) { + struct ggml_tensor* guidance, + std::vector ref_latents = {}, + std::vector skip_layers = {}) { GGML_ASSERT(x->ne[3] == 1); struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, FLUX_GRAPH_SIZE, false); - x = to_backend(x); - context = to_backend(context); - y = to_backend(y); + struct ggml_tensor* mod_index_arange = NULL; + + x = to_backend(x); + context = to_backend(context); + if (c_concat != NULL) { + c_concat = to_backend(c_concat); + } + if (flux_params.is_chroma) { + guidance = ggml_set_f32(guidance, 0); + + if (!use_mask) { + y = NULL; + } + + // ggml_arange is not working on some backends, precompute it + mod_index_arange_vec = arange(0, 344); + mod_index_arange = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_F32, mod_index_arange_vec.size()); + set_backend_tensor_data(mod_index_arange, mod_index_arange_vec.data()); + } + y = to_backend(y); + timesteps = to_backend(timesteps); - if (flux_params.guidance_embed) { + if (flux_params.guidance_embed || flux_params.is_chroma) { guidance = to_backend(guidance); } + for (int i = 0; i < ref_latents.size(); i++) { + ref_latents[i] = to_backend(ref_latents[i]); + } - pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], flux_params.theta, flux_params.axes_dim); + pe_vec = flux.gen_pe(x->ne[1], x->ne[0], 2, x->ne[3], context->ne[1], ref_latents, flux_params.theta, flux_params.axes_dim); int pos_len = pe_vec.size() / flux_params.axes_dim_sum / 2; // LOG_DEBUG("pos_len %d", pos_len); auto pe = ggml_new_tensor_4d(compute_ctx, GGML_TYPE_F32, 2, 2, flux_params.axes_dim_sum / 2, pos_len); @@ -854,9 +1165,13 @@ namespace Flux { x, timesteps, context, + c_concat, y, guidance, - pe); + pe, + mod_index_arange, + ref_latents, + skip_layers); ggml_build_forward_expand(gf, out); @@ -867,17 +1182,20 @@ namespace Flux { struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, + struct ggml_tensor* c_concat, struct ggml_tensor* y, struct ggml_tensor* guidance, - struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + std::vector ref_latents = {}, + struct ggml_tensor** output = NULL, + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] // guidance: [N, ] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y, guidance); + return build_graph(x, timesteps, context, c_concat, y, guidance, ref_latents, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -917,7 +1235,7 @@ namespace Flux { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - compute(8, x, timesteps, context, y, guidance, &out, work_ctx); + compute(8, x, timesteps, context, NULL, y, guidance, {}, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -929,7 +1247,7 @@ namespace Flux { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_Q8_0; - std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend, model_data_type)); + std::shared_ptr flux = std::shared_ptr(new FluxRunner(backend)); { LOG_INFO("loading from '%s'", file_path.c_str()); @@ -958,4 +1276,4 @@ namespace Flux { } // namespace Flux -#endif // __FLUX_HPP__ \ No newline at end of file +#endif // __FLUX_HPP__ diff --git a/ggml b/ggml index 21d3a308f..9e4bee1c5 160000 --- a/ggml +++ b/ggml @@ -1 +1 @@ -Subproject commit 21d3a308fcb7f31cb9beceaeebad4fb622f3c337 +Subproject commit 9e4bee1c5afc2d677a5b32ecb90cbdb483e81fff diff --git a/ggml_extend.hpp b/ggml_extend.hpp index e50137d5e..9f6a4fef6 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -22,9 +22,12 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#include "ggml-cpu.h" #include "ggml.h" -#ifdef SD_USE_CUBLAS +#include "model.h" + +#ifdef SD_USE_CUDA #include "ggml-cuda.h" #endif @@ -36,6 +39,10 @@ #include "ggml-vulkan.h" #endif +#ifdef SD_USE_OPENCL +#include "ggml-opencl.h" +#endif + #ifdef SD_USE_SYCL #include "ggml-sycl.h" #endif @@ -49,6 +56,72 @@ #define __STATIC_INLINE__ static inline #endif +// n-mode trensor-matrix product +// example: 2-mode product +// A: [ne03, k, ne01, ne00] +// B: k rows, m columns => [k, m] +// result is [ne03, m, ne01, ne00] +__STATIC_INLINE__ struct ggml_tensor* ggml_mul_n_mode(struct ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b, int mode = 0) { + // reshape A + // swap 0th and nth axis + a = ggml_cont(ctx, ggml_permute(ctx, a, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0)); + int ne1 = a->ne[1]; + int ne2 = a->ne[2]; + int ne3 = a->ne[3]; + // make 2D + a = ggml_cont(ctx, ggml_reshape_2d(ctx, a, a->ne[0], (ne3 * ne2 * ne1))); + + struct ggml_tensor* result = ggml_cont(ctx, ggml_transpose(ctx, ggml_mul_mat(ctx, a, b))); + + // reshape output (same shape as a after permutation except first dim) + result = ggml_reshape_4d(ctx, result, result->ne[0], ne1, ne2, ne3); + // swap back 0th and nth axis + result = ggml_permute(ctx, result, mode, mode != 1 ? 1 : 0, mode != 2 ? 2 : 0, mode != 3 ? 3 : 0); + return result; +} + +__STATIC_INLINE__ struct ggml_tensor* ggml_merge_lora(ggml_context* ctx, struct ggml_tensor* lora_down, struct ggml_tensor* lora_up, struct ggml_tensor* lora_mid = NULL) { + struct ggml_tensor* updown; + // flat lora tensors to multiply it + int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1]; + lora_up = ggml_reshape_2d(ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows); + auto lora_down_n_dims = ggml_n_dims(lora_down); + // assume n_dims should always be a multiple of 2 (otherwise rank 1 doesn't work) + lora_down_n_dims = (lora_down_n_dims + lora_down_n_dims % 2); + int64_t lora_down_rows = lora_down->ne[lora_down_n_dims - 1]; + lora_down = ggml_reshape_2d(ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows); + + // ggml_mul_mat requires tensor b transposed + lora_down = ggml_cont(ctx, ggml_transpose(ctx, lora_down)); + if (lora_mid == NULL) { + updown = ggml_mul_mat(ctx, lora_up, lora_down); + updown = ggml_cont(ctx, ggml_transpose(ctx, updown)); + } else { + // undoing tucker decomposition for conv layers. + // lora_mid has shape (3, 3, Rank, Rank) + // lora_down has shape (Rank, In, 1, 1) + // lora_up has shape (Rank, Out, 1, 1) + // conv layer shape is (3, 3, Out, In) + updown = ggml_mul_n_mode(ctx, ggml_mul_n_mode(ctx, lora_mid, lora_down, 3), lora_up, 2); + updown = ggml_cont(ctx, updown); + } + return updown; +} + +// Kronecker product +// [ne03,ne02,ne01,ne00] x [ne13,ne12,ne11,ne10] => [ne03*ne13,ne02*ne12,ne01*ne11,ne00*ne10] +__STATIC_INLINE__ struct ggml_tensor* ggml_kronecker(ggml_context* ctx, struct ggml_tensor* a, struct ggml_tensor* b) { + return ggml_mul(ctx, + ggml_upscale_ext(ctx, + a, + a->ne[0] * b->ne[0], + a->ne[1] * b->ne[1], + a->ne[2] * b->ne[2], + a->ne[3] * b->ne[3], + GGML_SCALE_MODE_NEAREST), + b); +} + __STATIC_INLINE__ void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) { (void)level; (void)user_data; @@ -100,17 +173,11 @@ __STATIC_INLINE__ ggml_fp16_t ggml_tensor_get_f16(const ggml_tensor* tensor, int static struct ggml_tensor* get_tensor_from_graph(struct ggml_cgraph* gf, const char* name) { struct ggml_tensor* res = NULL; - for (int i = 0; i < gf->n_nodes; i++) { - // printf("%d, %s \n", i, gf->nodes[i]->name); - if (strcmp(ggml_get_name(gf->nodes[i]), name) == 0) { - res = gf->nodes[i]; - break; - } - } - for (int i = 0; i < gf->n_leafs; i++) { - // printf("%d, %s \n", i, gf->leafs[i]->name); - if (strcmp(ggml_get_name(gf->leafs[i]), name) == 0) { - res = gf->leafs[i]; + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + struct ggml_tensor* node = ggml_graph_node(gf, i); + // printf("%d, %s \n", i, ggml_get_name(node)); + if (strcmp(ggml_get_name(node), name) == 0) { + res = node; break; } } @@ -293,6 +360,44 @@ __STATIC_INLINE__ void sd_image_to_tensor(const uint8_t* image_data, } } +__STATIC_INLINE__ void sd_mask_to_tensor(const uint8_t* image_data, + struct ggml_tensor* output, + bool scale = true) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(channels == 1 && output->type == GGML_TYPE_F32); + for (int iy = 0; iy < height; iy++) { + for (int ix = 0; ix < width; ix++) { + float value = *(image_data + iy * width * channels + ix); + if (scale) { + value /= 255.f; + } + ggml_tensor_set_f32(output, value, ix, iy); + } + } +} + +__STATIC_INLINE__ void sd_apply_mask(struct ggml_tensor* image_data, + struct ggml_tensor* mask, + struct ggml_tensor* output) { + int64_t width = output->ne[0]; + int64_t height = output->ne[1]; + int64_t channels = output->ne[2]; + GGML_ASSERT(output->type == GGML_TYPE_F32); + for (int ix = 0; ix < width; ix++) { + for (int iy = 0; iy < height; iy++) { + float m = ggml_tensor_get_f32(mask, ix, iy); + m = round(m); // inpaint models need binary masks + ggml_tensor_set_f32(mask, m, ix, iy); + for (int k = 0; k < channels; k++) { + float value = (1 - m) * (ggml_tensor_get_f32(image_data, ix, iy, k) - .5) + .5; + ggml_tensor_set_f32(output, value, ix, iy, k); + } + } + } +} + __STATIC_INLINE__ void sd_mul_images_to_tensor(const uint8_t* image_data, struct ggml_tensor* output, int idx, @@ -497,6 +602,8 @@ typedef std::function on_tile_process; // Tiling __STATIC_INLINE__ void sd_tiling(ggml_tensor* input, ggml_tensor* output, const int scale, const int tile_size, const float tile_overlap_factor, on_tile_process on_processing) { + output = ggml_set_f32(output, 0); + int input_width = (int)input->ne[0]; int input_height = (int)input->ne[1]; int output_width = (int)output->ne[0]; @@ -675,18 +782,16 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention(struct ggml_context* ctx struct ggml_tensor* k, struct ggml_tensor* v, bool mask = false) { -#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUBLAS) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL) +#if defined(SD_USE_FLASH_ATTENTION) && !defined(SD_USE_CUDA) && !defined(SD_USE_METAL) && !defined(SD_USE_VULKAN) && !defined(SD_USE_SYCL) struct ggml_tensor* kqv = ggml_flash_attn(ctx, q, k, v, false); // [N * n_head, n_token, d_head] #else - float d_head = (float)q->ne[0]; - + float d_head = (float)q->ne[0]; struct ggml_tensor* kq = ggml_mul_mat(ctx, k, q); // [N * n_head, n_token, n_k] kq = ggml_scale_inplace(ctx, kq, 1.0f / sqrt(d_head)); if (mask) { kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); } - kq = ggml_soft_max_inplace(ctx, kq); - + kq = ggml_soft_max_inplace(ctx, kq); struct ggml_tensor* kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, n_token, d_head] #endif return kqv; @@ -703,7 +808,8 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* int64_t n_head, struct ggml_tensor* mask = NULL, bool diag_mask_inf = false, - bool skip_reshape = false) { + bool skip_reshape = false, + bool flash_attn = false) { int64_t L_q; int64_t L_k; int64_t C; @@ -734,13 +840,54 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* float scale = (1.0f / sqrt((float)d_head)); - bool use_flash_attn = false; - ggml_tensor* kqv = NULL; - if (use_flash_attn) { + // if (flash_attn) { + // LOG_DEBUG("attention_ext L_q:%d L_k:%d n_head:%d C:%d d_head:%d N:%d", L_q, L_k, n_head, C, d_head, N); + // } + // is there anything oddly shaped?? ping Green-Sky if you can trip this assert + GGML_ASSERT(((L_k % 256 == 0) && L_q == L_k) || !(L_k % 256 == 0)); + + bool can_use_flash_attn = true; + can_use_flash_attn = can_use_flash_attn && L_k % 256 == 0; + can_use_flash_attn = can_use_flash_attn && d_head % 64 == 0; // double check + + // cuda max d_head seems to be 256, cpu does seem to work with 512 + can_use_flash_attn = can_use_flash_attn && d_head <= 256; // double check + + if (mask != nullptr) { + // TODO(Green-Sky): figure out if we can bend t5 to work too + can_use_flash_attn = can_use_flash_attn && mask->ne[2] == 1; + can_use_flash_attn = can_use_flash_attn && mask->ne[3] == 1; + } + + // TODO(Green-Sky): more pad or disable for funny tensor shapes + + ggml_tensor* kqv = nullptr; + // GGML_ASSERT((flash_attn && can_use_flash_attn) || !flash_attn); + if (can_use_flash_attn && flash_attn) { + // LOG_DEBUG("using flash attention"); + k = ggml_cast(ctx, k, GGML_TYPE_F16); + v = ggml_cont(ctx, ggml_permute(ctx, v, 0, 2, 1, 3)); // [N, n_head, L_k, d_head] v = ggml_reshape_3d(ctx, v, d_head, L_k, n_head * N); // [N * n_head, L_k, d_head] - LOG_DEBUG("k->ne[1] == %d", k->ne[1]); + v = ggml_cast(ctx, v, GGML_TYPE_F16); + + if (mask != nullptr) { + mask = ggml_transpose(ctx, mask); + + if (mask->ne[1] < GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)) { + LOG_DEBUG("mask dims %ld, %ld, %ld, %ld\n", mask->ne[0], mask->ne[1], mask->ne[2], mask->ne[3]); + LOG_DEBUG("needs padding, padding from %ld to %ld\n", mask->ne[1], GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD)); + mask = ggml_pad(ctx, mask, 0, GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) - mask->ne[1], 0, 0); + } + + mask = ggml_cast(ctx, mask, GGML_TYPE_F16); + } + kqv = ggml_flash_attn_ext(ctx, q, k, v, mask, scale, 0, 0); + ggml_flash_attn_ext_set_prec(kqv, GGML_PREC_F32); + + // kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_k, kqv->nb[1], kqv->nb[2], 0); + kqv = ggml_view_3d(ctx, kqv, d_head, n_head, L_q, kqv->nb[1], kqv->nb[2], 0); } else { v = ggml_cont(ctx, ggml_permute(ctx, v, 1, 2, 0, 3)); // [N, n_head, d_head, L_k] v = ggml_reshape_3d(ctx, v, L_k, d_head, n_head * N); // [N * n_head, d_head, L_k] @@ -748,7 +895,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* auto kq = ggml_mul_mat(ctx, k, q); // [N * n_head, L_q, L_k] kq = ggml_scale_inplace(ctx, kq, scale); if (mask) { - kq = ggml_add(ctx, kq, mask); + kq = ggml_add_inplace(ctx, kq, mask); } if (diag_mask_inf) { kq = ggml_diag_mask_inf_inplace(ctx, kq, 0); @@ -756,10 +903,12 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_attention_ext(struct ggml_context* kq = ggml_soft_max_inplace(ctx, kq); kqv = ggml_mul_mat(ctx, v, kq); // [N * n_head, L_q, d_head] + + kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] + kqv = ggml_permute(ctx, kqv, 0, 2, 1, 3); // [N, L_q, n_head, d_head] } - kqv = ggml_reshape_4d(ctx, kqv, d_head, L_q, n_head, N); // [N, n_head, L_q, d_head] - kqv = ggml_cont(ctx, ggml_permute(ctx, kqv, 0, 2, 1, 3)); // [N, L_q, n_head, d_head] + kqv = ggml_cont(ctx, kqv); kqv = ggml_reshape_3d(ctx, kqv, d_head * n_head, L_q, N); // [N, L_q, C] return kqv; @@ -801,7 +950,7 @@ __STATIC_INLINE__ struct ggml_tensor* ggml_nn_group_norm(struct ggml_context* ct } __STATIC_INLINE__ void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor* tensor, void* data, size_t offset, size_t size) { -#if defined(SD_USE_CUBLAS) || defined(SD_USE_SYCL) +#if defined(SD_USE_CUDA) || defined(SD_USE_SYCL) if (!ggml_backend_is_cpu(backend)) { ggml_backend_tensor_get_async(backend, tensor, data, offset, size); ggml_backend_synchronize(backend); @@ -924,8 +1073,8 @@ __STATIC_INLINE__ size_t ggml_tensor_num(ggml_context* ctx) { } /* SDXL with LoRA requires more space */ -#define MAX_PARAMS_TENSOR_NUM 15360 -#define MAX_GRAPH_SIZE 15360 +#define MAX_PARAMS_TENSOR_NUM 32768 +#define MAX_GRAPH_SIZE 32768 struct GGMLRunner { protected: @@ -939,7 +1088,6 @@ struct GGMLRunner { std::map backend_tensor_data_map; - ggml_type wtype = GGML_TYPE_F32; ggml_backend_t backend = NULL; void alloc_params_ctx() { @@ -1015,8 +1163,8 @@ struct GGMLRunner { public: virtual std::string get_desc() = 0; - GGMLRunner(ggml_backend_t backend, ggml_type wtype = GGML_TYPE_F32) - : backend(backend), wtype(wtype) { + GGMLRunner(ggml_backend_t backend) + : backend(backend) { alloc_params_ctx(); } @@ -1047,6 +1195,11 @@ struct GGMLRunner { params_buffer_size / (1024.0 * 1024.0), ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", num_tensors); + // printf("%s params backend buffer size = % 6.2f MB(%s) (%i tensors)\n", + // get_desc().c_str(), + // params_buffer_size / (1024.0 * 1024.0), + // ggml_backend_is_cpu(backend) ? "RAM" : "VRAM", + // num_tensors); return true; } @@ -1107,18 +1260,12 @@ struct GGMLRunner { ggml_backend_cpu_set_n_threads(backend, n_threads); } -#ifdef SD_USE_METAL - if (ggml_backend_is_metal(backend)) { - ggml_backend_metal_set_n_cb(backend, n_threads); - } -#endif ggml_backend_graph_compute(backend, gf); - #ifdef GGML_PERF ggml_graph_print(gf); #endif if (output != NULL) { - auto result = gf->nodes[gf->n_nodes - 1]; + auto result = ggml_graph_node(gf, -1); if (*output == NULL && output_ctx != NULL) { *output = ggml_dup_tensor(output_ctx, result); } @@ -1140,20 +1287,22 @@ class GGMLBlock { GGMLBlockMap blocks; ParameterMap params; - void init_blocks(struct ggml_context* ctx, ggml_type wtype) { + void init_blocks(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { for (auto& pair : blocks) { auto& block = pair.second; - - block->init(ctx, wtype); + block->init(ctx, tensor_types, prefix + pair.first); } } - virtual void init_params(struct ggml_context* ctx, ggml_type wtype) {} + virtual void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") {} public: - void init(struct ggml_context* ctx, ggml_type wtype) { - init_blocks(ctx, wtype); - init_params(ctx, wtype); + void init(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + if (prefix.size() > 0) { + prefix = prefix + "."; + } + init_blocks(ctx, tensor_types, prefix); + init_params(ctx, tensor_types, prefix); } size_t get_params_num() { @@ -1209,13 +1358,15 @@ class Linear : public UnaryBlock { bool bias; bool force_f32; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; if (in_features % ggml_blck_size(wtype) != 0 || force_f32) { wtype = GGML_TYPE_F32; } params["weight"] = ggml_new_tensor_2d(ctx, wtype, in_features, out_features); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_features); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_features); } } @@ -1243,9 +1394,9 @@ class Embedding : public UnaryBlock { protected: int64_t embedding_dim; int64_t num_embeddings; - - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings); } public: @@ -1283,10 +1434,12 @@ class Conv2d : public UnaryBlock { std::pair dilation; bool bias; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kernel_size.second, kernel_size.first, in_channels, out_channels); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F16; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16; + params["weight"] = ggml_new_tensor_4d(ctx, wtype, kernel_size.second, kernel_size.first, in_channels, out_channels); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + enum ggml_type wtype = GGML_TYPE_F32; // (tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels); } } @@ -1326,10 +1479,12 @@ class Conv3dnx1x1 : public UnaryBlock { int64_t dilation; bool bias; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, 1, kernel_size, in_channels, out_channels); // 5d => 4d + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F16; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F16; + params["weight"] = ggml_new_tensor_4d(ctx, wtype, 1, kernel_size, in_channels, out_channels); // 5d => 4d if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, out_channels); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, out_channels); } } @@ -1368,11 +1523,13 @@ class LayerNorm : public UnaryBlock { bool elementwise_affine; bool bias; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { if (elementwise_affine) { - params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, normalized_shape); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape); if (bias) { - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, normalized_shape); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.ypes.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["bias"] = ggml_new_tensor_1d(ctx, wtype, normalized_shape); } } } @@ -1408,10 +1565,12 @@ class GroupNorm : public GGMLBlock { float eps; bool affine; - void init_params(struct ggml_context* ctx, ggml_type wtype) { + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { if (affine) { - params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_channels); - params["bias"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, num_channels); + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + enum ggml_type bias_wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "bias") != tensor_types.end()) ? tensor_types[prefix + "bias"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, num_channels); + params["bias"] = ggml_new_tensor_1d(ctx, bias_wtype, num_channels); } } diff --git a/gits_noise.inl b/gits_noise.inl index fd4750267..7a10ff76f 100644 --- a/gits_noise.inl +++ b/gits_noise.inl @@ -329,21 +329,21 @@ const std::vector> GITS_NOISE_1_50 = { }; const std::vector>*> GITS_NOISE = { - { &GITS_NOISE_0_80 }, - { &GITS_NOISE_0_85 }, - { &GITS_NOISE_0_90 }, - { &GITS_NOISE_0_95 }, - { &GITS_NOISE_1_00 }, - { &GITS_NOISE_1_05 }, - { &GITS_NOISE_1_10 }, - { &GITS_NOISE_1_15 }, - { &GITS_NOISE_1_20 }, - { &GITS_NOISE_1_25 }, - { &GITS_NOISE_1_30 }, - { &GITS_NOISE_1_35 }, - { &GITS_NOISE_1_40 }, - { &GITS_NOISE_1_45 }, - { &GITS_NOISE_1_50 } + &GITS_NOISE_0_80, + &GITS_NOISE_0_85, + &GITS_NOISE_0_90, + &GITS_NOISE_0_95, + &GITS_NOISE_1_00, + &GITS_NOISE_1_05, + &GITS_NOISE_1_10, + &GITS_NOISE_1_15, + &GITS_NOISE_1_20, + &GITS_NOISE_1_25, + &GITS_NOISE_1_30, + &GITS_NOISE_1_35, + &GITS_NOISE_1_40, + &GITS_NOISE_1_45, + &GITS_NOISE_1_50 }; #endif // GITS_NOISE_INL diff --git a/lora.hpp b/lora.hpp index c44db7698..35f5aacd1 100644 --- a/lora.hpp +++ b/lora.hpp @@ -3,9 +3,93 @@ #include "ggml_extend.hpp" -#define LORA_GRAPH_SIZE 10240 +#define LORA_GRAPH_BASE_SIZE 10240 struct LoraModel : public GGMLRunner { + enum lora_t { + REGULAR = 0, + DIFFUSERS = 1, + DIFFUSERS_2 = 2, + DIFFUSERS_3 = 3, + TRANSFORMERS = 4, + LORA_TYPE_COUNT + }; + + const std::string lora_ups[LORA_TYPE_COUNT] = { + ".lora_up", + "_lora.up", + ".lora_B", + ".lora.up", + ".lora_linear_layer.up", + }; + + const std::string lora_downs[LORA_TYPE_COUNT] = { + ".lora_down", + "_lora.down", + ".lora_A", + ".lora.down", + ".lora_linear_layer.down", + }; + + const std::string lora_pre[LORA_TYPE_COUNT] = { + "lora.", + "", + "", + "", + "", + }; + + const std::map alt_names = { + // mmdit + {"final_layer.adaLN_modulation.1", "norm_out.linear"}, + {"pos_embed", "pos_embed.proj"}, + {"final_layer.linear", "proj_out"}, + {"y_embedder.mlp.0", "time_text_embed.text_embedder.linear_1"}, + {"y_embedder.mlp.2", "time_text_embed.text_embedder.linear_2"}, + {"t_embedder.mlp.0", "time_text_embed.timestep_embedder.linear_1"}, + {"t_embedder.mlp.2", "time_text_embed.timestep_embedder.linear_2"}, + {"x_block.mlp.fc1", "ff.net.0.proj"}, + {"x_block.mlp.fc2", "ff.net.2"}, + {"context_block.mlp.fc1", "ff_context.net.0.proj"}, + {"context_block.mlp.fc2", "ff_context.net.2"}, + {"x_block.adaLN_modulation.1", "norm1.linear"}, + {"context_block.adaLN_modulation.1", "norm1_context.linear"}, + {"context_block.attn.proj", "attn.to_add_out"}, + {"x_block.attn.proj", "attn.to_out.0"}, + {"x_block.attn2.proj", "attn2.to_out.0"}, + // flux + // singlestream + {"linear2", "proj_out"}, + {"modulation.lin", "norm.linear"}, + // doublestream + {"txt_attn.proj", "attn.to_add_out"}, + {"img_attn.proj", "attn.to_out.0"}, + {"txt_mlp.0", "ff_context.net.0.proj"}, + {"txt_mlp.2", "ff_context.net.2"}, + {"img_mlp.0", "ff.net.0.proj"}, + {"img_mlp.2", "ff.net.2"}, + {"txt_mod.lin", "norm1_context.linear"}, + {"img_mod.lin", "norm1.linear"}, + }; + + const std::map qkv_prefixes = { + // mmdit + {"context_block.attn.qkv", "attn.add_"}, // suffix "_proj" + {"x_block.attn.qkv", "attn.to_"}, + {"x_block.attn2.qkv", "attn2.to_"}, + // flux + // doublestream + {"txt_attn.qkv", "attn.add_"}, // suffix "_proj" + {"img_attn.qkv", "attn.to_"}, + }; + const std::map qkvm_prefixes = { + // flux + // singlestream + {"linear1", ""}, + }; + + const std::string* type_fingerprints = lora_ups; + float multiplier = 1.0f; std::map lora_tensors; std::string file_path; @@ -14,12 +98,12 @@ struct LoraModel : public GGMLRunner { bool applied = false; std::vector zero_index_vec = {0}; ggml_tensor* zero_index = NULL; + enum lora_t type = REGULAR; LoraModel(ggml_backend_t backend, - ggml_type wtype, const std::string& file_path = "", - const std::string& prefix = "") - : file_path(file_path), GGMLRunner(backend, wtype) { + const std::string prefix = "") + : file_path(file_path), GGMLRunner(backend) { if (!model_loader.init_from_file(file_path, prefix)) { load_failed = true; } @@ -45,6 +129,13 @@ struct LoraModel : public GGMLRunner { // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); return true; } + // LOG_INFO("%s", name.c_str()); + for (int i = 0; i < LORA_TYPE_COUNT; i++) { + if (name.find(type_fingerprints[i]) != std::string::npos) { + type = (lora_t)i; + break; + } + } if (dry_run) { struct ggml_tensor* real = ggml_new_tensor(params_ctx, @@ -62,10 +153,12 @@ struct LoraModel : public GGMLRunner { model_loader.load_tensors(on_new_tensor_cb, backend); alloc_params_buffer(); - + // exit(0); dry_run = false; model_loader.load_tensors(on_new_tensor_cb, backend); + LOG_DEBUG("lora type: \"%s\"/\"%s\"", lora_downs[type].c_str(), lora_ups[type].c_str()); + LOG_DEBUG("finished loaded lora"); return true; } @@ -77,103 +170,653 @@ struct LoraModel : public GGMLRunner { return out; } - struct ggml_cgraph* build_lora_graph(std::map model_tensors) { - struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, LORA_GRAPH_SIZE, false); + std::vector to_lora_keys(std::string blk_name, SDVersion version) { + std::vector keys; + // if (!sd_version_is_sd3(version) || blk_name != "model.diffusion_model.pos_embed") { + size_t k_pos = blk_name.find(".weight"); + if (k_pos == std::string::npos) { + return keys; + } + blk_name = blk_name.substr(0, k_pos); + // } + keys.push_back(blk_name); + keys.push_back("lora." + blk_name); + if (sd_version_is_dit(version)) { + if (blk_name.find("model.diffusion_model") != std::string::npos) { + blk_name.replace(blk_name.find("model.diffusion_model"), sizeof("model.diffusion_model") - 1, "transformer"); + } - zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); - set_backend_tensor_data(zero_index, zero_index_vec.data()); - ggml_build_forward_expand(gf, zero_index); + if (blk_name.find(".single_blocks") != std::string::npos) { + blk_name.replace(blk_name.find(".single_blocks"), sizeof(".single_blocks") - 1, ".single_transformer_blocks"); + } + if (blk_name.find(".double_blocks") != std::string::npos) { + blk_name.replace(blk_name.find(".double_blocks"), sizeof(".double_blocks") - 1, ".transformer_blocks"); + } - std::set applied_lora_tensors; - for (auto it : model_tensors) { - std::string k_tensor = it.first; - struct ggml_tensor* weight = model_tensors[it.first]; + if (blk_name.find(".joint_blocks") != std::string::npos) { + blk_name.replace(blk_name.find(".joint_blocks"), sizeof(".joint_blocks") - 1, ".transformer_blocks"); + } - size_t k_pos = k_tensor.find(".weight"); - if (k_pos == std::string::npos) { - continue; + if (blk_name.find("text_encoders.clip_l") != std::string::npos) { + blk_name.replace(blk_name.find("text_encoders.clip_l"), sizeof("text_encoders.clip_l") - 1, "cond_stage_model"); + } + + for (const auto& item : alt_names) { + size_t match = blk_name.find(item.first); + if (match != std::string::npos) { + blk_name = blk_name.substr(0, match) + item.second; + } + } + for (const auto& prefix : qkv_prefixes) { + size_t match = blk_name.find(prefix.first); + if (match != std::string::npos) { + std::string split_blk = "SPLIT|" + blk_name.substr(0, match) + prefix.second; + keys.push_back(split_blk); + } } - k_tensor = k_tensor.substr(0, k_pos); - replace_all_chars(k_tensor, '.', '_'); - // LOG_DEBUG("k_tensor %s", k_tensor.c_str()); - std::string lora_up_name = "lora." + k_tensor + ".lora_up.weight"; - if (lora_tensors.find(lora_up_name) == lora_tensors.end()) { - if (k_tensor == "model_diffusion_model_output_blocks_2_2_conv") { - // fix for some sdxl lora, like lcm-lora-xl - k_tensor = "model_diffusion_model_output_blocks_2_1_conv"; - lora_up_name = "lora." + k_tensor + ".lora_up.weight"; + for (const auto& prefix : qkvm_prefixes) { + size_t match = blk_name.find(prefix.first); + if (match != std::string::npos) { + std::string split_blk = "SPLIT_L|" + blk_name.substr(0, match) + prefix.second; + keys.push_back(split_blk); } } + keys.push_back(blk_name); + } - std::string lora_down_name = "lora." + k_tensor + ".lora_down.weight"; - std::string alpha_name = "lora." + k_tensor + ".alpha"; - std::string scale_name = "lora." + k_tensor + ".scale"; + std::vector ret; + for (std::string& key : keys) { + ret.push_back(key); + replace_all_chars(key, '.', '_'); + // fix for some sdxl lora, like lcm-lora-xl + if (key == "model_diffusion_model_output_blocks_2_2_conv") { + ret.push_back("model_diffusion_model_output_blocks_2_1_conv"); + } + ret.push_back(key); + } + return ret; + } - ggml_tensor* lora_up = NULL; - ggml_tensor* lora_down = NULL; + struct ggml_cgraph* build_lora_graph(std::map model_tensors, SDVersion version) { + size_t lora_graph_size = LORA_GRAPH_BASE_SIZE + lora_tensors.size() * 10; + struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, lora_graph_size, false); - if (lora_tensors.find(lora_up_name) != lora_tensors.end()) { - lora_up = lora_tensors[lora_up_name]; - } + zero_index = ggml_new_tensor_1d(compute_ctx, GGML_TYPE_I32, 1); + set_backend_tensor_data(zero_index, zero_index_vec.data()); + ggml_build_forward_expand(gf, zero_index); - if (lora_tensors.find(lora_down_name) != lora_tensors.end()) { - lora_down = lora_tensors[lora_down_name]; - } + std::set applied_lora_tensors; + for (auto it : model_tensors) { + std::string k_tensor = it.first; + struct ggml_tensor* weight = model_tensors[it.first]; - if (lora_up == NULL || lora_down == NULL) { + std::vector keys = to_lora_keys(k_tensor, version); + if (keys.size() == 0) continue; - } - applied_lora_tensors.insert(lora_up_name); - applied_lora_tensors.insert(lora_down_name); - applied_lora_tensors.insert(alpha_name); - applied_lora_tensors.insert(scale_name); - - // calc_cale - int64_t dim = lora_down->ne[ggml_n_dims(lora_down) - 1]; - float scale_value = 1.0f; - if (lora_tensors.find(scale_name) != lora_tensors.end()) { - scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]); - } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { - float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); - scale_value = alpha / dim; - } - scale_value *= multiplier; - - // flat lora tensors to multiply it - int64_t lora_up_rows = lora_up->ne[ggml_n_dims(lora_up) - 1]; - lora_up = ggml_reshape_2d(compute_ctx, lora_up, ggml_nelements(lora_up) / lora_up_rows, lora_up_rows); - int64_t lora_down_rows = lora_down->ne[ggml_n_dims(lora_down) - 1]; - lora_down = ggml_reshape_2d(compute_ctx, lora_down, ggml_nelements(lora_down) / lora_down_rows, lora_down_rows); - - // ggml_mul_mat requires tensor b transposed - lora_down = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, lora_down)); - struct ggml_tensor* updown = ggml_mul_mat(compute_ctx, lora_up, lora_down); - updown = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, updown)); - updown = ggml_reshape(compute_ctx, updown, weight); - GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight)); - updown = ggml_scale_inplace(compute_ctx, updown, scale_value); - ggml_tensor* final_weight; - if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { - // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne); - // final_weight = ggml_cpy(compute_ctx, weight, final_weight); - final_weight = to_f32(compute_ctx, weight); - final_weight = ggml_add_inplace(compute_ctx, final_weight, updown); - final_weight = ggml_cpy(compute_ctx, final_weight, weight); - } else { - final_weight = ggml_add_inplace(compute_ctx, weight, updown); + for (auto& key : keys) { + bool is_qkv_split = starts_with(key, "SPLIT|"); + if (is_qkv_split) { + key = key.substr(sizeof("SPLIT|") - 1); + } + bool is_qkvm_split = starts_with(key, "SPLIT_L|"); + if (is_qkvm_split) { + key = key.substr(sizeof("SPLIT_L|") - 1); + } + struct ggml_tensor* updown = NULL; + float scale_value = 1.0f; + std::string fk = lora_pre[type] + key; + if (lora_tensors.find(fk + ".hada_w1_a") != lora_tensors.end()) { + // LoHa mode + + // TODO: split qkv convention for LoHas (is it ever used?) + if (is_qkv_split || is_qkvm_split) { + LOG_ERROR("Split qkv isn't supported for LoHa models."); + break; + } + std::string alpha_name = ""; + + ggml_tensor* hada_1_mid = NULL; // tau for tucker decomposition + ggml_tensor* hada_1_up = NULL; + ggml_tensor* hada_1_down = NULL; + + ggml_tensor* hada_2_mid = NULL; // tau for tucker decomposition + ggml_tensor* hada_2_up = NULL; + ggml_tensor* hada_2_down = NULL; + + std::string hada_1_mid_name = ""; + std::string hada_1_down_name = ""; + std::string hada_1_up_name = ""; + + std::string hada_2_mid_name = ""; + std::string hada_2_down_name = ""; + std::string hada_2_up_name = ""; + + hada_1_down_name = fk + ".hada_w1_b"; + hada_1_up_name = fk + ".hada_w1_a"; + hada_1_mid_name = fk + ".hada_t1"; + if (lora_tensors.find(hada_1_down_name) != lora_tensors.end()) { + hada_1_down = to_f32(compute_ctx, lora_tensors[hada_1_down_name]); + } + if (lora_tensors.find(hada_1_up_name) != lora_tensors.end()) { + hada_1_up = to_f32(compute_ctx, lora_tensors[hada_1_up_name]); + } + if (lora_tensors.find(hada_1_mid_name) != lora_tensors.end()) { + hada_1_mid = to_f32(compute_ctx, lora_tensors[hada_1_mid_name]); + applied_lora_tensors.insert(hada_1_mid_name); + hada_1_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_1_up)); + } + + hada_2_down_name = fk + ".hada_w2_b"; + hada_2_up_name = fk + ".hada_w2_a"; + hada_2_mid_name = fk + ".hada_t2"; + if (lora_tensors.find(hada_2_down_name) != lora_tensors.end()) { + hada_2_down = to_f32(compute_ctx, lora_tensors[hada_2_down_name]); + } + if (lora_tensors.find(hada_2_up_name) != lora_tensors.end()) { + hada_2_up = to_f32(compute_ctx, lora_tensors[hada_2_up_name]); + } + if (lora_tensors.find(hada_2_mid_name) != lora_tensors.end()) { + hada_2_mid = to_f32(compute_ctx, lora_tensors[hada_2_mid_name]); + applied_lora_tensors.insert(hada_2_mid_name); + hada_2_up = ggml_cont(compute_ctx, ggml_transpose(compute_ctx, hada_2_up)); + } + + alpha_name = fk + ".alpha"; + + applied_lora_tensors.insert(hada_1_down_name); + applied_lora_tensors.insert(hada_1_up_name); + applied_lora_tensors.insert(hada_2_down_name); + applied_lora_tensors.insert(hada_2_up_name); + + applied_lora_tensors.insert(alpha_name); + if (hada_1_up == NULL || hada_1_down == NULL || hada_2_up == NULL || hada_2_down == NULL) { + continue; + } + + struct ggml_tensor* updown_1 = ggml_merge_lora(compute_ctx, hada_1_down, hada_1_up, hada_1_mid); + struct ggml_tensor* updown_2 = ggml_merge_lora(compute_ctx, hada_2_down, hada_2_up, hada_2_mid); + updown = ggml_mul_inplace(compute_ctx, updown_1, updown_2); + + // calc_scale + // TODO: .dora_scale? + int64_t rank = hada_1_down->ne[ggml_n_dims(hada_1_down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } else if (lora_tensors.find(fk + ".lokr_w1") != lora_tensors.end() || lora_tensors.find(fk + ".lokr_w1_a") != lora_tensors.end()) { + // LoKr mode + + // TODO: split qkv convention for LoKrs (is it ever used?) + if (is_qkv_split || is_qkvm_split) { + LOG_ERROR("Split qkv isn't supported for LoKr models."); + break; + } + + std::string alpha_name = fk + ".alpha"; + + ggml_tensor* lokr_w1 = NULL; + ggml_tensor* lokr_w2 = NULL; + + std::string lokr_w1_name = ""; + std::string lokr_w2_name = ""; + + lokr_w1_name = fk + ".lokr_w1"; + lokr_w2_name = fk + ".lokr_w2"; + + if (lora_tensors.find(lokr_w1_name) != lora_tensors.end()) { + lokr_w1 = to_f32(compute_ctx, lora_tensors[lokr_w1_name]); + applied_lora_tensors.insert(lokr_w1_name); + } else { + ggml_tensor* down = NULL; + ggml_tensor* up = NULL; + std::string down_name = lokr_w1_name + "_b"; + std::string up_name = lokr_w1_name + "_a"; + if (lora_tensors.find(down_name) != lora_tensors.end()) { + // w1 should not be low rank normally, sometimes w1 and w2 are swapped + down = to_f32(compute_ctx, lora_tensors[down_name]); + applied_lora_tensors.insert(down_name); + + int64_t rank = down->ne[ggml_n_dims(down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } + if (lora_tensors.find(up_name) != lora_tensors.end()) { + up = to_f32(compute_ctx, lora_tensors[up_name]); + applied_lora_tensors.insert(up_name); + } + lokr_w1 = ggml_merge_lora(compute_ctx, down, up); + } + if (lora_tensors.find(lokr_w2_name) != lora_tensors.end()) { + lokr_w2 = to_f32(compute_ctx, lora_tensors[lokr_w2_name]); + applied_lora_tensors.insert(lokr_w2_name); + } else { + ggml_tensor* down = NULL; + ggml_tensor* up = NULL; + std::string down_name = lokr_w2_name + "_b"; + std::string up_name = lokr_w2_name + "_a"; + if (lora_tensors.find(down_name) != lora_tensors.end()) { + down = to_f32(compute_ctx, lora_tensors[down_name]); + applied_lora_tensors.insert(down_name); + + int64_t rank = down->ne[ggml_n_dims(down) - 1]; + if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + } + if (lora_tensors.find(up_name) != lora_tensors.end()) { + up = to_f32(compute_ctx, lora_tensors[up_name]); + applied_lora_tensors.insert(up_name); + } + lokr_w2 = ggml_merge_lora(compute_ctx, down, up); + } + + // Technically it might be unused, but I believe it's the expected behavior + applied_lora_tensors.insert(alpha_name); + + updown = ggml_kronecker(compute_ctx, lokr_w1, lokr_w2); + + } else { + // LoRA mode + ggml_tensor* lora_mid = NULL; // tau for tucker decomposition + ggml_tensor* lora_up = NULL; + ggml_tensor* lora_down = NULL; + + std::string alpha_name = ""; + std::string scale_name = ""; + std::string split_q_scale_name = ""; + std::string lora_mid_name = ""; + std::string lora_down_name = ""; + std::string lora_up_name = ""; + + if (is_qkv_split) { + std::string suffix = ""; + auto split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight"; + + if (lora_tensors.find(split_q_d_name) == lora_tensors.end()) { + suffix = "_proj"; + split_q_d_name = fk + "q" + suffix + lora_downs[type] + ".weight"; + } + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] + // find qkv and mlp up parts in LoRA model + auto split_k_d_name = fk + "k" + suffix + lora_downs[type] + ".weight"; + auto split_v_d_name = fk + "v" + suffix + lora_downs[type] + ".weight"; + + auto split_q_u_name = fk + "q" + suffix + lora_ups[type] + ".weight"; + auto split_k_u_name = fk + "k" + suffix + lora_ups[type] + ".weight"; + auto split_v_u_name = fk + "v" + suffix + lora_ups[type] + ".weight"; + + auto split_q_scale_name = fk + "q" + suffix + ".scale"; + auto split_k_scale_name = fk + "k" + suffix + ".scale"; + auto split_v_scale_name = fk + "v" + suffix + ".scale"; + + auto split_q_alpha_name = fk + "q" + suffix + ".alpha"; + auto split_k_alpha_name = fk + "k" + suffix + ".alpha"; + auto split_v_alpha_name = fk + "v" + suffix + ".alpha"; + + ggml_tensor* lora_q_down = NULL; + ggml_tensor* lora_q_up = NULL; + ggml_tensor* lora_k_down = NULL; + ggml_tensor* lora_k_up = NULL; + ggml_tensor* lora_v_down = NULL; + ggml_tensor* lora_v_up = NULL; + + lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); + + if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + } + + if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { + lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); + } + + if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { + lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); + } + + if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { + lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); + } + + if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { + lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); + } + + float q_rank = lora_q_up->ne[0]; + float k_rank = lora_k_up->ne[0]; + float v_rank = lora_v_up->ne[0]; + + float lora_q_scale = 1; + float lora_k_scale = 1; + float lora_v_scale = 1; + + if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { + lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); + applied_lora_tensors.insert(split_q_scale_name); + } + if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { + lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); + applied_lora_tensors.insert(split_k_scale_name); + } + if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { + lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); + applied_lora_tensors.insert(split_v_scale_name); + } + + if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { + float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); + applied_lora_tensors.insert(split_q_alpha_name); + lora_q_scale = lora_q_alpha / q_rank; + } + if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { + float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); + applied_lora_tensors.insert(split_k_alpha_name); + lora_k_scale = lora_k_alpha / k_rank; + } + if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { + float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); + applied_lora_tensors.insert(split_v_alpha_name); + lora_v_scale = lora_v_alpha / v_rank; + } + + ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); + ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); + ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); + + // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] + + // these need to be stitched together this way: + // |q_up,0 ,0 | + // |0 ,k_up,0 | + // |0 ,0 ,v_up| + // (q_down,k_down,v_down) . (q ,k ,v) + + // up_concat will be [9216, R*3, 1, 1] + // down_concat will be [R*3, 3072, 1, 1] + ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), lora_v_down, 1); + + ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); + ggml_scale(compute_ctx, z, 0); + ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); + + ggml_tensor* q_up = ggml_concat(compute_ctx, lora_q_up, zz, 1); + ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), z, 1); + ggml_tensor* v_up = ggml_concat(compute_ctx, zz, lora_v_up, 1); + // print_ggml_tensor(q_up, true); //[R, 9216, 1, 1] + // print_ggml_tensor(k_up, true); //[R, 9216, 1, 1] + // print_ggml_tensor(v_up, true); //[R, 9216, 1, 1] + ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), v_up, 0); + // print_ggml_tensor(lora_up_concat, true); //[R*3, 9216, 1, 1] + + lora_down = ggml_cont(compute_ctx, lora_down_concat); + lora_up = ggml_cont(compute_ctx, lora_up_concat); + + applied_lora_tensors.insert(split_q_u_name); + applied_lora_tensors.insert(split_k_u_name); + applied_lora_tensors.insert(split_v_u_name); + + applied_lora_tensors.insert(split_q_d_name); + applied_lora_tensors.insert(split_k_d_name); + applied_lora_tensors.insert(split_v_d_name); + } + } else if (is_qkvm_split) { + auto split_q_d_name = fk + "attn.to_q" + lora_downs[type] + ".weight"; + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + // print_ggml_tensor(it.second, true); //[3072, 21504, 1, 1] + // find qkv and mlp up parts in LoRA model + auto split_k_d_name = fk + "attn.to_k" + lora_downs[type] + ".weight"; + auto split_v_d_name = fk + "attn.to_v" + lora_downs[type] + ".weight"; + + auto split_q_u_name = fk + "attn.to_q" + lora_ups[type] + ".weight"; + auto split_k_u_name = fk + "attn.to_k" + lora_ups[type] + ".weight"; + auto split_v_u_name = fk + "attn.to_v" + lora_ups[type] + ".weight"; + + auto split_m_d_name = fk + "proj_mlp" + lora_downs[type] + ".weight"; + auto split_m_u_name = fk + "proj_mlp" + lora_ups[type] + ".weight"; + + auto split_q_scale_name = fk + "attn.to_q" + ".scale"; + auto split_k_scale_name = fk + "attn.to_k" + ".scale"; + auto split_v_scale_name = fk + "attn.to_v" + ".scale"; + auto split_m_scale_name = fk + "proj_mlp" + ".scale"; + + auto split_q_alpha_name = fk + "attn.to_q" + ".alpha"; + auto split_k_alpha_name = fk + "attn.to_k" + ".alpha"; + auto split_v_alpha_name = fk + "attn.to_v" + ".alpha"; + auto split_m_alpha_name = fk + "proj_mlp" + ".alpha"; + + ggml_tensor* lora_q_down = NULL; + ggml_tensor* lora_q_up = NULL; + ggml_tensor* lora_k_down = NULL; + ggml_tensor* lora_k_up = NULL; + ggml_tensor* lora_v_down = NULL; + ggml_tensor* lora_v_up = NULL; + + ggml_tensor* lora_m_down = NULL; + ggml_tensor* lora_m_up = NULL; + + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + + if (lora_tensors.find(split_q_d_name) != lora_tensors.end()) { + lora_q_down = to_f32(compute_ctx, lora_tensors[split_q_d_name]); + } + + if (lora_tensors.find(split_q_u_name) != lora_tensors.end()) { + lora_q_up = to_f32(compute_ctx, lora_tensors[split_q_u_name]); + } + + if (lora_tensors.find(split_k_d_name) != lora_tensors.end()) { + lora_k_down = to_f32(compute_ctx, lora_tensors[split_k_d_name]); + } + + if (lora_tensors.find(split_k_u_name) != lora_tensors.end()) { + lora_k_up = to_f32(compute_ctx, lora_tensors[split_k_u_name]); + } + + if (lora_tensors.find(split_v_d_name) != lora_tensors.end()) { + lora_v_down = to_f32(compute_ctx, lora_tensors[split_v_d_name]); + } + + if (lora_tensors.find(split_v_u_name) != lora_tensors.end()) { + lora_v_up = to_f32(compute_ctx, lora_tensors[split_v_u_name]); + } + + if (lora_tensors.find(split_m_d_name) != lora_tensors.end()) { + lora_m_down = to_f32(compute_ctx, lora_tensors[split_m_d_name]); + } + + if (lora_tensors.find(split_m_u_name) != lora_tensors.end()) { + lora_m_up = to_f32(compute_ctx, lora_tensors[split_m_u_name]); + } + + float q_rank = lora_q_up->ne[0]; + float k_rank = lora_k_up->ne[0]; + float v_rank = lora_v_up->ne[0]; + float m_rank = lora_v_up->ne[0]; + + float lora_q_scale = 1; + float lora_k_scale = 1; + float lora_v_scale = 1; + float lora_m_scale = 1; + + if (lora_tensors.find(split_q_scale_name) != lora_tensors.end()) { + lora_q_scale = ggml_backend_tensor_get_f32(lora_tensors[split_q_scale_name]); + applied_lora_tensors.insert(split_q_scale_name); + } + if (lora_tensors.find(split_k_scale_name) != lora_tensors.end()) { + lora_k_scale = ggml_backend_tensor_get_f32(lora_tensors[split_k_scale_name]); + applied_lora_tensors.insert(split_k_scale_name); + } + if (lora_tensors.find(split_v_scale_name) != lora_tensors.end()) { + lora_v_scale = ggml_backend_tensor_get_f32(lora_tensors[split_v_scale_name]); + applied_lora_tensors.insert(split_v_scale_name); + } + if (lora_tensors.find(split_m_scale_name) != lora_tensors.end()) { + lora_m_scale = ggml_backend_tensor_get_f32(lora_tensors[split_m_scale_name]); + applied_lora_tensors.insert(split_m_scale_name); + } + + if (lora_tensors.find(split_q_alpha_name) != lora_tensors.end()) { + float lora_q_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_q_alpha_name]); + applied_lora_tensors.insert(split_q_alpha_name); + lora_q_scale = lora_q_alpha / q_rank; + } + if (lora_tensors.find(split_k_alpha_name) != lora_tensors.end()) { + float lora_k_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_k_alpha_name]); + applied_lora_tensors.insert(split_k_alpha_name); + lora_k_scale = lora_k_alpha / k_rank; + } + if (lora_tensors.find(split_v_alpha_name) != lora_tensors.end()) { + float lora_v_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_v_alpha_name]); + applied_lora_tensors.insert(split_v_alpha_name); + lora_v_scale = lora_v_alpha / v_rank; + } + if (lora_tensors.find(split_m_alpha_name) != lora_tensors.end()) { + float lora_m_alpha = ggml_backend_tensor_get_f32(lora_tensors[split_m_alpha_name]); + applied_lora_tensors.insert(split_m_alpha_name); + lora_m_scale = lora_m_alpha / m_rank; + } + + ggml_scale_inplace(compute_ctx, lora_q_down, lora_q_scale); + ggml_scale_inplace(compute_ctx, lora_k_down, lora_k_scale); + ggml_scale_inplace(compute_ctx, lora_v_down, lora_v_scale); + ggml_scale_inplace(compute_ctx, lora_m_down, lora_m_scale); + + // print_ggml_tensor(lora_q_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_k_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_v_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_m_down, true); //[3072, R, 1, 1] + // print_ggml_tensor(lora_q_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_k_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_v_up, true); //[R, 3072, 1, 1] + // print_ggml_tensor(lora_m_up, true); //[R, 12288, 1, 1] + + // these need to be stitched together this way: + // |q_up,0 ,0 ,0 | + // |0 ,k_up,0 ,0 | + // |0 ,0 ,v_up,0 | + // |0 ,0 ,0 ,m_up| + // (q_down,k_down,v_down,m_down) . (q ,k ,v ,m) + + // up_concat will be [21504, R*4, 1, 1] + // down_concat will be [R*4, 3072, 1, 1] + + ggml_tensor* lora_down_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_down, lora_k_down, 1), ggml_concat(compute_ctx, lora_v_down, lora_m_down, 1), 1); + // print_ggml_tensor(lora_down_concat, true); //[3072, R*4, 1, 1] + + // this also means that if rank is bigger than 672, it is less memory efficient to do it this way (should be fine) + // print_ggml_tensor(lora_q_up, true); //[3072, R, 1, 1] + ggml_tensor* z = ggml_dup_tensor(compute_ctx, lora_q_up); + ggml_tensor* mlp_z = ggml_dup_tensor(compute_ctx, lora_m_up); + ggml_scale(compute_ctx, z, 0); + ggml_scale(compute_ctx, mlp_z, 0); + ggml_tensor* zz = ggml_concat(compute_ctx, z, z, 1); + + ggml_tensor* q_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, lora_q_up, zz, 1), mlp_z, 1); + ggml_tensor* k_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, z, lora_k_up, 1), ggml_concat(compute_ctx, z, mlp_z, 1), 1); + ggml_tensor* v_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, lora_v_up, 1), mlp_z, 1); + ggml_tensor* m_up = ggml_concat(compute_ctx, ggml_concat(compute_ctx, zz, z, 1), lora_m_up, 1); + // print_ggml_tensor(q_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(k_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(v_up, true); //[R, 21504, 1, 1] + // print_ggml_tensor(m_up, true); //[R, 21504, 1, 1] + + ggml_tensor* lora_up_concat = ggml_concat(compute_ctx, ggml_concat(compute_ctx, q_up, k_up, 0), ggml_concat(compute_ctx, v_up, m_up, 0), 0); + // print_ggml_tensor(lora_up_concat, true); //[R*4, 21504, 1, 1] + + lora_down = ggml_cont(compute_ctx, lora_down_concat); + lora_up = ggml_cont(compute_ctx, lora_up_concat); + + applied_lora_tensors.insert(split_q_u_name); + applied_lora_tensors.insert(split_k_u_name); + applied_lora_tensors.insert(split_v_u_name); + applied_lora_tensors.insert(split_m_u_name); + + applied_lora_tensors.insert(split_q_d_name); + applied_lora_tensors.insert(split_k_d_name); + applied_lora_tensors.insert(split_v_d_name); + applied_lora_tensors.insert(split_m_d_name); + } + } else { + lora_up_name = fk + lora_ups[type] + ".weight"; + lora_down_name = fk + lora_downs[type] + ".weight"; + lora_mid_name = fk + ".lora_mid.weight"; + + alpha_name = fk + ".alpha"; + scale_name = fk + ".scale"; + + if (lora_tensors.find(lora_up_name) != lora_tensors.end()) { + lora_up = to_f32(compute_ctx, lora_tensors[lora_up_name]); + } + + if (lora_tensors.find(lora_down_name) != lora_tensors.end()) { + lora_down = to_f32(compute_ctx, lora_tensors[lora_down_name]); + } + + if (lora_tensors.find(lora_mid_name) != lora_tensors.end()) { + lora_mid = to_f32(compute_ctx, lora_tensors[lora_mid_name]); + applied_lora_tensors.insert(lora_mid_name); + } + + applied_lora_tensors.insert(lora_up_name); + applied_lora_tensors.insert(lora_down_name); + applied_lora_tensors.insert(alpha_name); + applied_lora_tensors.insert(scale_name); + } + + if (lora_up == NULL || lora_down == NULL) { + continue; + } + // calc_scale + // TODO: .dora_scale? + int64_t rank = lora_down->ne[ggml_n_dims(lora_down) - 1]; + if (lora_tensors.find(scale_name) != lora_tensors.end()) { + scale_value = ggml_backend_tensor_get_f32(lora_tensors[scale_name]); + } else if (lora_tensors.find(alpha_name) != lora_tensors.end()) { + float alpha = ggml_backend_tensor_get_f32(lora_tensors[alpha_name]); + scale_value = alpha / rank; + } + + updown = ggml_merge_lora(compute_ctx, lora_down, lora_up, lora_mid); + } + scale_value *= multiplier; + updown = ggml_reshape(compute_ctx, updown, weight); + GGML_ASSERT(ggml_nelements(updown) == ggml_nelements(weight)); + updown = ggml_scale_inplace(compute_ctx, updown, scale_value); + ggml_tensor* final_weight; + if (weight->type != GGML_TYPE_F32 && weight->type != GGML_TYPE_F16) { + // final_weight = ggml_new_tensor(compute_ctx, GGML_TYPE_F32, ggml_n_dims(weight), weight->ne); + // final_weight = ggml_cpy(compute_ctx, weight, final_weight); + final_weight = to_f32(compute_ctx, weight); + final_weight = ggml_add_inplace(compute_ctx, final_weight, updown); + final_weight = ggml_cpy(compute_ctx, final_weight, weight); + } else { + final_weight = ggml_add_inplace(compute_ctx, weight, updown); + } + // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly + ggml_build_forward_expand(gf, final_weight); + break; } - // final_weight = ggml_add_inplace(compute_ctx, weight, updown); // apply directly - ggml_build_forward_expand(gf, final_weight); } - size_t total_lora_tensors_count = 0; size_t applied_lora_tensors_count = 0; for (auto& kv : lora_tensors) { total_lora_tensors_count++; if (applied_lora_tensors.find(kv.first) == applied_lora_tensors.end()) { - LOG_WARN("unused lora tensor %s", kv.first.c_str()); + LOG_WARN("unused lora tensor |%s|", kv.first.c_str()); + print_ggml_tensor(kv.second, true); + // exit(0); } else { applied_lora_tensors_count++; } @@ -192,9 +835,9 @@ struct LoraModel : public GGMLRunner { return gf; } - void apply(std::map model_tensors, int n_threads) { + void apply(std::map model_tensors, SDVersion version, int n_threads) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_lora_graph(model_tensors); + return build_lora_graph(model_tensors, version); }; GGMLRunner::compute(get_graph, n_threads, true); } diff --git a/mmdit.hpp b/mmdit.hpp index 3a278dac7..dee7b1c49 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -147,8 +147,9 @@ class RMSNorm : public UnaryBlock { int64_t hidden_size; float eps; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } public: @@ -252,6 +253,7 @@ struct DismantledBlock : public GGMLBlock { public: int64_t num_heads; bool pre_only; + bool self_attn; public: DismantledBlock(int64_t hidden_size, @@ -259,14 +261,19 @@ struct DismantledBlock : public GGMLBlock { float mlp_ratio = 4.0, std::string qk_norm = "", bool qkv_bias = false, - bool pre_only = false) - : num_heads(num_heads), pre_only(pre_only) { + bool pre_only = false, + bool self_attn = false) + : num_heads(num_heads), pre_only(pre_only), self_attn(self_attn) { // rmsnorm is always Flase // scale_mod_only is always Flase // swiglu is always Flase blocks["norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only)); + if (self_attn) { + blocks["attn2"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, false)); + } + if (!pre_only) { blocks["norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); int64_t mlp_hidden_dim = (int64_t)(hidden_size * mlp_ratio); @@ -277,9 +284,52 @@ struct DismantledBlock : public GGMLBlock { if (pre_only) { n_mods = 2; } + if (self_attn) { + n_mods = 9; + } blocks["adaLN_modulation.1"] = std::shared_ptr(new Linear(hidden_size, n_mods * hidden_size)); } + std::tuple, std::vector, std::vector> pre_attention_x(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* c) { + GGML_ASSERT(self_attn); + // x: [N, n_token, hidden_size] + // c: [N, hidden_size] + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto adaLN_modulation_1 = std::dynamic_pointer_cast(blocks["adaLN_modulation.1"]); + + int64_t n_mods = 9; + auto m = adaLN_modulation_1->forward(ctx, ggml_silu(ctx, c)); // [N, n_mods * hidden_size] + m = ggml_reshape_3d(ctx, m, c->ne[0], n_mods, c->ne[1]); // [N, n_mods, hidden_size] + m = ggml_cont(ctx, ggml_permute(ctx, m, 0, 2, 1, 3)); // [n_mods, N, hidden_size] + + int64_t offset = m->nb[1] * m->ne[1]; + auto shift_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 0); // [N, hidden_size] + auto scale_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 1); // [N, hidden_size] + auto gate_msa = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 2); // [N, hidden_size] + + auto shift_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 3); // [N, hidden_size] + auto scale_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 4); // [N, hidden_size] + auto gate_mlp = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 5); // [N, hidden_size] + + auto shift_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 6); // [N, hidden_size] + auto scale_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 7); // [N, hidden_size] + auto gate_msa2 = ggml_view_2d(ctx, m, m->ne[0], m->ne[1], m->nb[1], offset * 8); // [N, hidden_size] + + auto x_norm = norm1->forward(ctx, x); + + auto attn_in = modulate(ctx, x_norm, shift_msa, scale_msa); + auto qkv = attn->pre_attention(ctx, attn_in); + + auto attn2_in = modulate(ctx, x_norm, shift_msa2, scale_msa2); + auto qkv2 = attn2->pre_attention(ctx, attn2_in); + + return {qkv, qkv2, {x, gate_msa, shift_mlp, scale_mlp, gate_mlp, gate_msa2}}; + } + std::pair, std::vector> pre_attention(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c) { @@ -319,6 +369,44 @@ struct DismantledBlock : public GGMLBlock { } } + struct ggml_tensor* post_attention_x(struct ggml_context* ctx, + struct ggml_tensor* attn_out, + struct ggml_tensor* attn2_out, + struct ggml_tensor* x, + struct ggml_tensor* gate_msa, + struct ggml_tensor* shift_mlp, + struct ggml_tensor* scale_mlp, + struct ggml_tensor* gate_mlp, + struct ggml_tensor* gate_msa2) { + // attn_out: [N, n_token, hidden_size] + // x: [N, n_token, hidden_size] + // gate_msa: [N, hidden_size] + // shift_mlp: [N, hidden_size] + // scale_mlp: [N, hidden_size] + // gate_mlp: [N, hidden_size] + // return: [N, n_token, hidden_size] + GGML_ASSERT(!pre_only); + + auto attn = std::dynamic_pointer_cast(blocks["attn"]); + auto attn2 = std::dynamic_pointer_cast(blocks["attn2"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + auto mlp = std::dynamic_pointer_cast(blocks["mlp"]); + + gate_msa = ggml_reshape_3d(ctx, gate_msa, gate_msa->ne[0], 1, gate_msa->ne[1]); // [N, 1, hidden_size] + gate_mlp = ggml_reshape_3d(ctx, gate_mlp, gate_mlp->ne[0], 1, gate_mlp->ne[1]); // [N, 1, hidden_size] + gate_msa2 = ggml_reshape_3d(ctx, gate_msa2, gate_msa2->ne[0], 1, gate_msa2->ne[1]); // [N, 1, hidden_size] + + attn_out = attn->post_attention(ctx, attn_out); + attn2_out = attn2->post_attention(ctx, attn2_out); + + x = ggml_add(ctx, x, ggml_mul(ctx, attn_out, gate_msa)); + x = ggml_add(ctx, x, ggml_mul(ctx, attn2_out, gate_msa2)); + auto mlp_out = mlp->forward(ctx, modulate(ctx, norm2->forward(ctx, x), shift_mlp, scale_mlp)); + x = ggml_add(ctx, x, ggml_mul(ctx, mlp_out, gate_mlp)); + + return x; + } + struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* attn_out, struct ggml_tensor* x, @@ -357,29 +445,52 @@ struct DismantledBlock : public GGMLBlock { // return: [N, n_token, hidden_size] auto attn = std::dynamic_pointer_cast(blocks["attn"]); - - auto qkv_intermediates = pre_attention(ctx, x, c); - auto qkv = qkv_intermediates.first; - auto intermediates = qkv_intermediates.second; - - auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] - x = post_attention(ctx, - attn_out, - intermediates[0], - intermediates[1], - intermediates[2], - intermediates[3], - intermediates[4]); - return x; // [N, n_token, dim] + if (self_attn) { + auto qkv_intermediates = pre_attention_x(ctx, x, c); + // auto qkv = qkv_intermediates.first; + // auto intermediates = qkv_intermediates.second; + // no longer a pair, but a tuple + auto qkv = std::get<0>(qkv_intermediates); + auto qkv2 = std::get<1>(qkv_intermediates); + auto intermediates = std::get<2>(qkv_intermediates); + + auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] + auto attn2_out = ggml_nn_attention_ext(ctx, qkv2[0], qkv2[1], qkv2[2], num_heads); // [N, n_token, dim] + x = post_attention_x(ctx, + attn_out, + attn2_out, + intermediates[0], + intermediates[1], + intermediates[2], + intermediates[3], + intermediates[4], + intermediates[5]); + return x; // [N, n_token, dim] + } else { + auto qkv_intermediates = pre_attention(ctx, x, c); + auto qkv = qkv_intermediates.first; + auto intermediates = qkv_intermediates.second; + + auto attn_out = ggml_nn_attention_ext(ctx, qkv[0], qkv[1], qkv[2], num_heads); // [N, n_token, dim] + x = post_attention(ctx, + attn_out, + intermediates[0], + intermediates[1], + intermediates[2], + intermediates[3], + intermediates[4]); + return x; // [N, n_token, dim] + } } }; -__STATIC_INLINE__ std::pair block_mixing(struct ggml_context* ctx, - struct ggml_tensor* context, - struct ggml_tensor* x, - struct ggml_tensor* c, - std::shared_ptr context_block, - std::shared_ptr x_block) { +__STATIC_INLINE__ std::pair +block_mixing(struct ggml_context* ctx, + struct ggml_tensor* context, + struct ggml_tensor* x, + struct ggml_tensor* c, + std::shared_ptr context_block, + std::shared_ptr x_block) { // context: [N, n_context, hidden_size] // x: [N, n_token, hidden_size] // c: [N, hidden_size] @@ -387,10 +498,18 @@ __STATIC_INLINE__ std::pair block_mixi auto context_qkv = context_qkv_intermediates.first; auto context_intermediates = context_qkv_intermediates.second; - auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); - auto x_qkv = x_qkv_intermediates.first; - auto x_intermediates = x_qkv_intermediates.second; + std::vector x_qkv, x_qkv2, x_intermediates; + if (x_block->self_attn) { + auto x_qkv_intermediates = x_block->pre_attention_x(ctx, x, c); + x_qkv = std::get<0>(x_qkv_intermediates); + x_qkv2 = std::get<1>(x_qkv_intermediates); + x_intermediates = std::get<2>(x_qkv_intermediates); + } else { + auto x_qkv_intermediates = x_block->pre_attention(ctx, x, c); + x_qkv = x_qkv_intermediates.first; + x_intermediates = x_qkv_intermediates.second; + } std::vector qkv; for (int i = 0; i < 3; i++) { qkv.push_back(ggml_concat(ctx, context_qkv[i], x_qkv[i], 1)); @@ -429,13 +548,27 @@ __STATIC_INLINE__ std::pair block_mixi context = NULL; } - x = x_block->post_attention(ctx, - x_attn, - x_intermediates[0], - x_intermediates[1], - x_intermediates[2], - x_intermediates[3], - x_intermediates[4]); + if (x_block->self_attn) { + auto attn2 = ggml_nn_attention_ext(ctx, x_qkv2[0], x_qkv2[1], x_qkv2[2], x_block->num_heads); // [N, n_token, hidden_size] + + x = x_block->post_attention_x(ctx, + x_attn, + attn2, + x_intermediates[0], + x_intermediates[1], + x_intermediates[2], + x_intermediates[3], + x_intermediates[4], + x_intermediates[5]); + } else { + x = x_block->post_attention(ctx, + x_attn, + x_intermediates[0], + x_intermediates[1], + x_intermediates[2], + x_intermediates[3], + x_intermediates[4]); + } return {context, x}; } @@ -447,9 +580,10 @@ struct JointBlock : public GGMLBlock { float mlp_ratio = 4.0, std::string qk_norm = "", bool qkv_bias = false, - bool pre_only = false) { + bool pre_only = false, + bool self_attn_x = false) { blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only)); - blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false)); + blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false, self_attn_x)); } std::pair forward(struct ggml_context* ctx, @@ -503,10 +637,10 @@ struct FinalLayer : public GGMLBlock { struct MMDiT : public GGMLBlock { // Diffusion model with a Transformer backbone. protected: - SDVersion version = VERSION_SD3_2B; int64_t input_size = -1; int64_t patch_size = 2; int64_t in_channels = 16; + int64_t d_self = -1; // >=0 for MMdiT-X int64_t depth = 24; float mlp_ratio = 4.0f; int64_t adm_in_channels = 2048; @@ -518,13 +652,13 @@ struct MMDiT : public GGMLBlock { int64_t hidden_size; std::string qk_norm; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["pos_embed"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden_size, num_patchs, 1); + void init_params(struct ggml_context* ctx, std::map& tensor_types, std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "pos_embed") != tensor_types.end()) ? tensor_types[prefix + "pos_embed"] : GGML_TYPE_F32; + params["pos_embed"] = ggml_new_tensor_3d(ctx, wtype, hidden_size, num_patchs, 1); } public: - MMDiT(SDVersion version = VERSION_SD3_2B) - : version(version) { + MMDiT(std::map& tensor_types) { // input_size is always None // learn_sigma is always False // register_length is alwalys 0 @@ -536,34 +670,44 @@ struct MMDiT : public GGMLBlock { // pos_embed_scaling_factor is not used // pos_embed_offset is not used // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}} - if (version == VERSION_SD3_2B) { - input_size = -1; - patch_size = 2; - in_channels = 16; - depth = 24; - mlp_ratio = 4.0f; - adm_in_channels = 2048; - out_channels = 16; - pos_embed_max_size = 192; - num_patchs = 36864; // 192 * 192 - context_size = 4096; - context_embedder_out_dim = 1536; - } else if (version == VERSION_SD3_5_8B) { - input_size = -1; - patch_size = 2; - in_channels = 16; - depth = 38; - mlp_ratio = 4.0f; - adm_in_channels = 2048; - out_channels = 16; - pos_embed_max_size = 192; - num_patchs = 36864; // 192 * 192 - context_size = 4096; - context_embedder_out_dim = 2432; - qk_norm = "rms"; + + // read tensors from tensor_types + for (auto pair : tensor_types) { + std::string tensor_name = pair.first; + if (tensor_name.find("model.diffusion_model.") == std::string::npos) + continue; + size_t jb = tensor_name.find("joint_blocks."); + if (jb != std::string::npos) { + tensor_name = tensor_name.substr(jb); // remove prefix + int block_depth = atoi(tensor_name.substr(13, tensor_name.find(".", 13)).c_str()); + if (block_depth + 1 > depth) { + depth = block_depth + 1; + } + if (tensor_name.find("attn.ln") != std::string::npos) { + if (tensor_name.find(".bias") != std::string::npos) { + qk_norm = "ln"; + } else { + qk_norm = "rms"; + } + } + if (tensor_name.find("attn2") != std::string::npos) { + if (block_depth > d_self) { + d_self = block_depth; + } + } + } } + + if (d_self >= 0) { + pos_embed_max_size *= 2; + num_patchs *= 4; + } + + LOG_INFO("MMDiT layers: %d (including %d MMDiT-x layers)", depth, d_self + 1); + int64_t default_out_channels = in_channels; hidden_size = 64 * depth; + context_embedder_out_dim = 64 * depth; int64_t num_heads = depth; blocks["x_embedder"] = std::shared_ptr(new PatchEmbed(input_size, patch_size, in_channels, hidden_size, true)); @@ -581,15 +725,17 @@ struct MMDiT : public GGMLBlock { mlp_ratio, qk_norm, true, - i == depth - 1)); + i == depth - 1, + i <= d_self)); } blocks["final_layer"] = std::shared_ptr(new FinalLayer(hidden_size, patch_size, out_channels)); } - struct ggml_tensor* cropped_pos_embed(struct ggml_context* ctx, - int64_t h, - int64_t w) { + struct ggml_tensor* + cropped_pos_embed(struct ggml_context* ctx, + int64_t h, + int64_t w) { auto pos_embed = params["pos_embed"]; h = (h + 1) / patch_size; @@ -651,7 +797,8 @@ struct MMDiT : public GGMLBlock { struct ggml_tensor* forward_core_with_concat(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* c_mod, - struct ggml_tensor* context) { + struct ggml_tensor* context, + std::vector skip_layers = std::vector()) { // x: [N, H*W, hidden_size] // context: [N, n_context, d_context] // c: [N, hidden_size] @@ -659,6 +806,11 @@ struct MMDiT : public GGMLBlock { auto final_layer = std::dynamic_pointer_cast(blocks["final_layer"]); for (int i = 0; i < depth; i++) { + // skip iteration if i is in skip_layers + if (skip_layers.size() > 0 && std::find(skip_layers.begin(), skip_layers.end(), i) != skip_layers.end()) { + continue; + } + auto block = std::dynamic_pointer_cast(blocks["joint_blocks." + std::to_string(i)]); auto context_x = block->forward(ctx, context, x, c_mod); @@ -674,8 +826,9 @@ struct MMDiT : public GGMLBlock { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x, struct ggml_tensor* t, - struct ggml_tensor* y = NULL, - struct ggml_tensor* context = NULL) { + struct ggml_tensor* y = NULL, + struct ggml_tensor* context = NULL, + std::vector skip_layers = std::vector()) { // Forward pass of DiT. // x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) // t: (N,) tensor of diffusion timesteps @@ -706,22 +859,23 @@ struct MMDiT : public GGMLBlock { context = context_embedder->forward(ctx, context); // [N, L, D] aka [N, L, 1536] } - x = forward_core_with_concat(ctx, x, c, context); // (N, H*W, patch_size ** 2 * out_channels) + x = forward_core_with_concat(ctx, x, c, context, skip_layers); // (N, H*W, patch_size ** 2 * out_channels) x = unpatchify(ctx, x, h, w); // [N, C, H, W] return x; } }; - struct MMDiTRunner : public GGMLRunner { MMDiT mmdit; + static std::map empty_tensor_types; + MMDiTRunner(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD3_2B) - : GGMLRunner(backend, wtype), mmdit(version) { - mmdit.init(params_ctx, wtype); + std::map& tensor_types = empty_tensor_types, + const std::string prefix = "") + : GGMLRunner(backend), mmdit(tensor_types) { + mmdit.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -735,7 +889,8 @@ struct MMDiTRunner : public GGMLRunner { struct ggml_cgraph* build_graph(struct ggml_tensor* x, struct ggml_tensor* timesteps, struct ggml_tensor* context, - struct ggml_tensor* y) { + struct ggml_tensor* y, + std::vector skip_layers = std::vector()) { struct ggml_cgraph* gf = ggml_new_graph_custom(compute_ctx, MMDIT_GRAPH_SIZE, false); x = to_backend(x); @@ -747,7 +902,8 @@ struct MMDiTRunner : public GGMLRunner { x, timesteps, y, - context); + context, + skip_layers); ggml_build_forward_expand(gf, out); @@ -760,13 +916,14 @@ struct MMDiTRunner : public GGMLRunner { struct ggml_tensor* context, struct ggml_tensor* y, struct ggml_tensor** output = NULL, - struct ggml_context* output_ctx = NULL) { + struct ggml_context* output_ctx = NULL, + std::vector skip_layers = std::vector()) { // x: [N, in_channels, h, w] // timesteps: [N, ] // context: [N, max_position, hidden_size]([N, 154, 4096]) or [1, max_position, hidden_size] // y: [N, adm_in_channels] or [1, adm_in_channels] auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(x, timesteps, context, y); + return build_graph(x, timesteps, context, y, skip_layers); }; GGMLRunner::compute(get_graph, n_threads, false, output, output_ctx); @@ -815,7 +972,7 @@ struct MMDiTRunner : public GGMLRunner { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F16; - std::shared_ptr mmdit = std::shared_ptr(new MMDiTRunner(backend, model_data_type)); + std::shared_ptr mmdit = std::shared_ptr(new MMDiTRunner(backend)); { LOG_INFO("loading from '%s'", file_path.c_str()); diff --git a/model.cpp b/model.cpp index 26451cdc5..2e40e004a 100644 --- a/model.cpp +++ b/model.cpp @@ -13,6 +13,7 @@ #include "ggml-alloc.h" #include "ggml-backend.h" +#include "ggml-cpu.h" #include "ggml.h" #include "stable-diffusion.h" @@ -25,6 +26,10 @@ #include "ggml-vulkan.h" #endif +#ifdef SD_USE_OPENCL +#include "ggml-opencl.h" +#endif + #define ST_HEADER_SIZE_LEN 8 uint64_t read_u64(uint8_t* buffer) { @@ -95,6 +100,7 @@ const char* unused_tensors[] = { "model_ema.diffusion_model", "embedding_manager", "denoiser.sigmas", + "text_encoders.t5xxl.transformer.encoder.embed_tokens.weight", // only used during training }; bool is_unused_tensor(std::string name) { @@ -146,9 +152,94 @@ std::unordered_map vae_decoder_name_map = { {"first_stage_model.decoder.mid.attn_1.to_v.weight", "first_stage_model.decoder.mid.attn_1.v.weight"}, }; +std::unordered_map pmid_v2_name_map = { + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.0.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.0.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.1.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.1.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.2.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.2.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc1.weight"}, + {"pmid.qformer_perceiver.perceiver_resampler.layers.3.1.3.weight", + "pmid.qformer_perceiver.perceiver_resampler.layers.3.1.1.fc2.weight"}, + {"pmid.qformer_perceiver.token_proj.0.bias", + "pmid.qformer_perceiver.token_proj.fc1.bias"}, + {"pmid.qformer_perceiver.token_proj.2.bias", + "pmid.qformer_perceiver.token_proj.fc2.bias"}, + {"pmid.qformer_perceiver.token_proj.0.weight", + "pmid.qformer_perceiver.token_proj.fc1.weight"}, + {"pmid.qformer_perceiver.token_proj.2.weight", + "pmid.qformer_perceiver.token_proj.fc2.weight"}, +}; + std::string convert_open_clip_to_hf_clip(const std::string& name) { std::string new_name = name; std::string prefix; + if (contains(new_name, ".enc.")) { + // llama.cpp naming convention for T5 + size_t pos = new_name.find(".enc."); + if (pos != std::string::npos) { + new_name.replace(pos, 5, ".encoder."); + } + pos = new_name.find("blk."); + if (pos != std::string::npos) { + new_name.replace(pos, 4, "block."); + } + pos = new_name.find("output_norm."); + if (pos != std::string::npos) { + new_name.replace(pos, 12, "final_layer_norm."); + } + pos = new_name.find("attn_k."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.k."); + } + pos = new_name.find("attn_v."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.v."); + } + pos = new_name.find("attn_o."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.o."); + } + pos = new_name.find("attn_q."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.0.SelfAttention.q."); + } + pos = new_name.find("attn_norm."); + if (pos != std::string::npos) { + new_name.replace(pos, 10, "layer.0.layer_norm."); + } + pos = new_name.find("ffn_norm."); + if (pos != std::string::npos) { + new_name.replace(pos, 9, "layer.1.layer_norm."); + } + pos = new_name.find("ffn_up."); + if (pos != std::string::npos) { + new_name.replace(pos, 7, "layer.1.DenseReluDense.wi_1."); + } + pos = new_name.find("ffn_down."); + if (pos != std::string::npos) { + new_name.replace(pos, 9, "layer.1.DenseReluDense.wo."); + } + pos = new_name.find("ffn_gate."); + if (pos != std::string::npos) { + new_name.replace(pos, 9, "layer.1.DenseReluDense.wi_0."); + } + pos = new_name.find("attn_rel_b."); + if (pos != std::string::npos) { + new_name.replace(pos, 11, "layer.0.SelfAttention.relative_attention_bias."); + } + } else if (name == "text_encoders.t5xxl.transformer.token_embd.weight") { + new_name = "text_encoders.t5xxl.transformer.shared.weight"; + } + if (starts_with(new_name, "conditioner.embedders.0.open_clip.")) { prefix = "cond_stage_model."; new_name = new_name.substr(strlen("conditioner.embedders.0.open_clip.")); @@ -212,6 +303,13 @@ std::string convert_vae_decoder_name(const std::string& name) { return name; } +std::string convert_pmid_v2_name(const std::string& name) { + if (pmid_v2_name_map.find(name) != pmid_v2_name_map.end()) { + return pmid_v2_name_map[name]; + } + return name; +} + /* If not a SDXL LoRA the unet" prefix will have already been replaced by this * point and "te2" and "te1" don't seem to appear in non-SDXL only "te_" */ std::string convert_sdxl_lora_name(std::string tensor_name) { @@ -240,6 +338,10 @@ std::unordered_map> su {"to_v", "v"}, {"to_out_0", "proj_out"}, {"group_norm", "norm"}, + {"key", "k"}, + {"query", "q"}, + {"value", "v"}, + {"proj_attn", "proj_out"}, }, }, { @@ -264,6 +366,10 @@ std::unordered_map> su {"to_v", "v"}, {"to_out.0", "proj_out"}, {"group_norm", "norm"}, + {"key", "k"}, + {"query", "q"}, + {"value", "v"}, + {"proj_attn", "proj_out"}, }, }, { @@ -335,6 +441,10 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) { return format("model%cdiffusion_model%ctime_embed%c", seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; } + if (match(m, std::regex(format("unet%cadd_embedding%clinear_(\\d+)(.*)", seq, seq)), key)) { + return format("model%cdiffusion_model%clabel_emb%c0%c", seq, seq, seq, seq) + std::to_string(std::stoi(m[0]) * 2 - 2) + m[1]; + } + if (match(m, std::regex(format("unet%cdown_blocks%c(\\d+)%c(attentions|resnets)%c(\\d+)%c(.+)", seq, seq, seq, seq, seq)), key)) { std::string suffix = get_converted_suffix(m[1], m[3]); // LOG_DEBUG("%s %s %s %s", m[0].c_str(), m[1].c_str(), m[2].c_str(), m[3].c_str()); @@ -372,6 +482,19 @@ std::string convert_diffusers_name_to_compvis(std::string key, char seq) { return format("cond_stage_model%ctransformer%ctext_model", seq, seq) + m[0]; } + // clip-g + if (match(m, std::regex(format("te%c1%ctext_model%cencoder%clayers%c(\\d+)%c(.+)", seq, seq, seq, seq, seq, seq)), key)) { + return format("cond_stage_model%c1%ctransformer%ctext_model%cencoder%clayers%c", seq, seq, seq, seq, seq, seq) + m[0] + seq + m[1]; + } + + if (match(m, std::regex(format("te%c1%ctext_model(.*)", seq, seq)), key)) { + return format("cond_stage_model%c1%ctransformer%ctext_model", seq, seq, seq) + m[0]; + } + + if (match(m, std::regex(format("te%c1%ctext_projection", seq, seq)), key)) { + return format("cond_stage_model%c1%ctransformer%ctext_model%ctext_projection", seq, seq, seq, seq); + } + // vae if (match(m, std::regex(format("vae%c(.*)%cconv_norm_out(.*)", seq, seq)), key)) { return format("first_stage_model%c%s%cnorm_out%s", seq, m[0].c_str(), seq, m[1].c_str()); @@ -443,6 +566,8 @@ std::string convert_tensor_name(std::string name) { new_name = convert_open_clip_to_hf_clip(name); } else if (starts_with(name, "first_stage_model.decoder")) { new_name = convert_vae_decoder_name(name); + } else if (starts_with(name, "pmid.qformer_perceiver")) { + new_name = convert_pmid_v2_name(name); } else if (starts_with(name, "control_model.")) { // for controlnet pth models size_t pos = name.find('.'); if (pos != std::string::npos) { @@ -506,6 +631,8 @@ std::string convert_tensor_name(std::string name) { std::string new_key = convert_diffusers_name_to_compvis(name_without_network_parts, '.'); if (new_key.empty()) { new_name = name; + } else if (new_key == "cond_stage_model.1.transformer.text_model.text_projection") { + new_name = new_key; } else { new_name = new_key + "." + network_part; } @@ -521,6 +648,26 @@ std::string convert_tensor_name(std::string name) { return new_name; } +void add_preprocess_tensor_storage_types(std::map& tensor_storages_types, std::string name, enum ggml_type type) { + std::string new_name = convert_tensor_name(name); + + if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_weight")) { + size_t prefix_size = new_name.find("attn.in_proj_weight"); + std::string prefix = new_name.substr(0, prefix_size); + tensor_storages_types[prefix + "self_attn.q_proj.weight"] = type; + tensor_storages_types[prefix + "self_attn.k_proj.weight"] = type; + tensor_storages_types[prefix + "self_attn.v_proj.weight"] = type; + } else if (new_name.find("cond_stage_model") != std::string::npos && ends_with(new_name, "attn.in_proj_bias")) { + size_t prefix_size = new_name.find("attn.in_proj_bias"); + std::string prefix = new_name.substr(0, prefix_size); + tensor_storages_types[prefix + "self_attn.q_proj.bias"] = type; + tensor_storages_types[prefix + "self_attn.k_proj.bias"] = type; + tensor_storages_types[prefix + "self_attn.v_proj.bias"] = type; + } else { + tensor_storages_types[new_name] = type; + } +} + void preprocess_tensor(TensorStorage tensor_storage, std::vector& processed_tensor_storages) { std::vector result; @@ -614,6 +761,47 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) { return ggml_fp32_to_fp16(*reinterpret_cast(&result)); } +uint16_t f8_e5m2_to_f16(uint8_t fp8) { + uint8_t sign = (fp8 >> 7) & 0x1; + uint8_t exponent = (fp8 >> 2) & 0x1F; + uint8_t mantissa = fp8 & 0x3; + + uint16_t fp16_sign = sign << 15; + uint16_t fp16_exponent; + uint16_t fp16_mantissa; + + if (exponent == 0 && mantissa == 0) { // zero + return fp16_sign; + } + + if (exponent == 0x1F) { // NAN and INF + fp16_exponent = 0x1F; + fp16_mantissa = mantissa ? (mantissa << 8) : 0; + return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; + } + + if (exponent == 0) { // subnormal numbers + fp16_exponent = 0; + fp16_mantissa = (mantissa << 8); + return fp16_sign | fp16_mantissa; + } + + // normal numbers + int16_t true_exponent = (int16_t)exponent - 15 + 15; + if (true_exponent <= 0) { + fp16_exponent = 0; + fp16_mantissa = (mantissa << 8); + } else if (true_exponent >= 0x1F) { + fp16_exponent = 0x1F; + fp16_mantissa = 0; + } else { + fp16_exponent = (uint16_t)true_exponent; + fp16_mantissa = mantissa << 8; + } + + return fp16_sign | (fp16_exponent << 10) | fp16_mantissa; +} + void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) { // support inplace op for (int64_t i = n - 1; i >= 0; i--) { @@ -627,6 +815,12 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) { dst[i] = f8_e4m3_to_f16(src[i]); } } +void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) { + // support inplace op + for (int64_t i = n - 1; i >= 0; i--) { + dst[i] = f8_e5m2_to_f16(src[i]); + } +} void convert_tensor(void* src, ggml_type src_type, @@ -650,25 +844,25 @@ void convert_tensor(void* src, if (src_type == GGML_TYPE_F16) { ggml_fp16_to_fp32_row((ggml_fp16_t*)src, (float*)dst, n); } else { - auto qtype = ggml_internal_get_type_traits(src_type); - if (qtype.to_float == NULL) { + auto qtype = ggml_get_type_traits(src_type); + if (qtype->to_float == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(src_type))); } - qtype.to_float(src, (float*)dst, n); + qtype->to_float(src, (float*)dst, n); } } else { // src_type == GGML_TYPE_F16 => dst_type is quantized // src_type is quantized => dst_type == GGML_TYPE_F16 or dst_type is quantized - auto qtype = ggml_internal_get_type_traits(src_type); - if (qtype.to_float == NULL) { + auto qtype = ggml_get_type_traits(src_type); + if (qtype->to_float == NULL) { throw std::runtime_error(format("type %s unsupported for integer quantization: no dequantization available", ggml_type_name(src_type))); } std::vector buf; buf.resize(sizeof(float) * n); char* src_data_f32 = buf.data(); - qtype.to_float(src, (float*)src_data_f32, n); + qtype->to_float(src, (float*)src_data_f32, n); if (dst_type == GGML_TYPE_F16) { ggml_fp32_to_fp16_row((float*)src_data_f32, (ggml_fp16_t*)dst, n); } else { @@ -843,6 +1037,7 @@ bool ModelLoader::init_from_gguf_file(const std::string& file_path, const std::s GGML_ASSERT(ggml_nbytes(dummy) == tensor_storage.nbytes()); tensor_storages.push_back(tensor_storage); + add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type); } gguf_free(ctx_gguf_); @@ -861,8 +1056,14 @@ ggml_type str_to_ggml_type(const std::string& dtype) { ttype = GGML_TYPE_F32; } else if (dtype == "F32") { ttype = GGML_TYPE_F32; + } else if (dtype == "F64") { + ttype = GGML_TYPE_F64; } else if (dtype == "F8_E4M3") { ttype = GGML_TYPE_F16; + } else if (dtype == "F8_E5M2") { + ttype = GGML_TYPE_F16; + } else if (dtype == "I64") { + ttype = GGML_TYPE_I64; } return ttype; } @@ -875,6 +1076,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const std::ifstream file(file_path, std::ios::binary); if (!file.is_open()) { LOG_ERROR("failed to open '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -886,6 +1088,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const // read header size if (file_size_ <= ST_HEADER_SIZE_LEN) { LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -899,6 +1102,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const size_t header_size_ = read_u64(header_size_buf); if (header_size_ >= file_size_) { LOG_ERROR("invalid safetensor file '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -909,6 +1113,7 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const file.read(header_buf.data(), header_size_); if (!file) { LOG_ERROR("read safetensors header failed: '%s'", file_path.c_str()); + file_paths_.pop_back(); return false; } @@ -976,11 +1181,16 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const tensor_storage.is_f8_e4m3 = true; // f8 -> f16 GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); + } else if (dtype == "F8_E5M2") { + tensor_storage.is_f8_e5m2 = true; + // f8 -> f16 + GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2); } else { GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size); } tensor_storages.push_back(tensor_storage); + add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type); // LOG_DEBUG("%s %s", tensor_storage.to_string().c_str(), dtype.c_str()); } @@ -991,18 +1201,45 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const /*================================================= DiffusersModelLoader ==================================================*/ bool ModelLoader::init_from_diffusers_file(const std::string& file_path, const std::string& prefix) { - std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); - std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors"); - std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); + std::string unet_path = path_join(file_path, "unet/diffusion_pytorch_model.safetensors"); + std::string vae_path = path_join(file_path, "vae/diffusion_pytorch_model.safetensors"); + std::string clip_path = path_join(file_path, "text_encoder/model.safetensors"); + std::string clip_g_path = path_join(file_path, "text_encoder_2/model.safetensors"); if (!init_from_safetensors_file(unet_path, "unet.")) { return false; } + for (auto ts : tensor_storages) { + if (ts.name.find("add_embedding") != std::string::npos || ts.name.find("label_emb") != std::string::npos) { + // probably SDXL + LOG_DEBUG("Fixing name for SDXL output blocks.2.2"); + for (auto& tensor_storage : tensor_storages) { + int len = 34; + auto pos = tensor_storage.name.find("unet.up_blocks.0.upsamplers.0.conv"); + if (pos == std::string::npos) { + len = 44; + pos = tensor_storage.name.find("model.diffusion_model.output_blocks.2.1.conv"); + } + if (pos != std::string::npos) { + tensor_storage.name = "model.diffusion_model.output_blocks.2.2.conv" + tensor_storage.name.substr(len); + LOG_DEBUG("NEW NAME: %s", tensor_storage.name.c_str()); + add_preprocess_tensor_storage_types(tensor_storages_types, tensor_storage.name, tensor_storage.type); + } + } + break; + } + } + if (!init_from_safetensors_file(vae_path, "vae.")) { - return false; + LOG_WARN("Couldn't find working VAE in %s", file_path.c_str()); + // return false; } if (!init_from_safetensors_file(clip_path, "te.")) { - return false; + LOG_WARN("Couldn't find working text encoder in %s", file_path.c_str()); + // return false; + } + if (!init_from_safetensors_file(clip_g_path, "te.1.")) { + LOG_DEBUG("Couldn't find working second text encoder in %s", file_path.c_str()); } return true; } @@ -1206,7 +1443,7 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, zip_t* zip, std::string dir, size_t file_index, - const std::string& prefix) { + const std::string prefix) { uint8_t* buffer_end = buffer + buffer_size; if (buffer[0] == 0x80) { // proto if (buffer[1] != 2) { @@ -1308,9 +1545,11 @@ bool ModelLoader::parse_data_pkl(uint8_t* buffer, reader.tensor_storage.reverse_ne(); reader.tensor_storage.file_index = file_index; // if(strcmp(prefix.c_str(), "scarlett") == 0) - // printf(" got tensor %s \n ", reader.tensor_storage.name.c_str()); + // printf(" ZIP got tensor %s \n ", reader.tensor_storage.name.c_str()); reader.tensor_storage.name = prefix + reader.tensor_storage.name; tensor_storages.push_back(reader.tensor_storage); + add_preprocess_tensor_storage_types(tensor_storages_types, reader.tensor_storage.name, reader.tensor_storage.type); + // LOG_DEBUG("%s", reader.tensor_storage.name.c_str()); // reset reader = PickleTensorReader(); @@ -1345,7 +1584,8 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s size_t pos = name.find("data.pkl"); if (pos != std::string::npos) { std::string dir = name.substr(0, pos); - void* pkl_data = NULL; + printf("ZIP %d, name = %s, dir = %s \n", i, name.c_str(), dir.c_str()); + void* pkl_data = NULL; size_t pkl_size; zip_entry_read(zip, &pkl_data, &pkl_size); @@ -1362,33 +1602,59 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s return true; } +bool ModelLoader::model_is_unet() { + for (auto& tensor_storage : tensor_storages) { + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos) { + return true; + } + } + return false; +} + SDVersion ModelLoader::get_sd_version() { - TensorStorage token_embedding_weight; + TensorStorage token_embedding_weight, input_block_weight; + bool input_block_checked = false; + + bool has_multiple_encoders = false; + bool is_unet = false; + + bool is_xl = false; bool is_flux = false; - bool is_sd3 = false; + +#define found_family (is_xl || is_flux) for (auto& tensor_storage : tensor_storages) { - if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { - return VERSION_FLUX_DEV; - } - if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { - is_flux = true; - } - if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) { - return VERSION_SD3_5_8B; - } - if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) { - is_sd3 = true; - } - if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) { - return VERSION_SDXL; - } - if (tensor_storage.name.find("cond_stage_model.1") != std::string::npos) { - return VERSION_SDXL; - } - if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { - return VERSION_SVD; + if (!found_family) { + if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { + is_flux = true; + if (input_block_checked) { + break; + } + } + if (tensor_storage.name.find("model.diffusion_model.joint_blocks.") != std::string::npos) { + return VERSION_SD3; + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.") != std::string::npos || tensor_storage.name.find("unet.down_blocks.") != std::string::npos) { + is_unet = true; + if (has_multiple_encoders) { + is_xl = true; + if (input_block_checked) { + break; + } + } + } + if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos || tensor_storage.name.find("cond_stage_model.1") != std::string::npos || tensor_storage.name.find("te.1") != std::string::npos) { + has_multiple_encoders = true; + if (is_unet) { + is_xl = true; + if (input_block_checked) { + break; + } + } + } + if (tensor_storage.name.find("model.diffusion_model.input_blocks.8.0.time_mixer.mix_factor") != std::string::npos) { + return VERSION_SVD; + } } - if (tensor_storage.name == "cond_stage_model.transformer.text_model.embeddings.token_embedding.weight" || tensor_storage.name == "cond_stage_model.model.token_embedding.weight" || tensor_storage.name == "text_model.embeddings.token_embedding.weight" || @@ -1398,16 +1664,46 @@ SDVersion ModelLoader::get_sd_version() { token_embedding_weight = tensor_storage; // break; } + if (tensor_storage.name == "model.diffusion_model.input_blocks.0.0.weight" || tensor_storage.name == "model.diffusion_model.img_in.weight" || tensor_storage.name == "unet.conv_in.weight") { + input_block_weight = tensor_storage; + input_block_checked = true; + if (found_family) { + break; + } + } } - if (is_flux) { - return VERSION_FLUX_SCHNELL; + bool is_inpaint = input_block_weight.ne[2] == 9; + bool is_ip2p = input_block_weight.ne[2] == 8; + if (is_xl) { + if (is_inpaint) { + return VERSION_SDXL_INPAINT; + } + if (is_ip2p) { + return VERSION_SDXL_PIX2PIX; + } + return VERSION_SDXL; } - if (is_sd3) { - return VERSION_SD3_2B; + + if (is_flux) { + is_inpaint = input_block_weight.ne[0] == 384; + if (is_inpaint) { + return VERSION_FLUX_FILL; + } + return VERSION_FLUX; } + if (token_embedding_weight.ne[0] == 768) { + if (is_inpaint) { + return VERSION_SD1_INPAINT; + } + if (is_ip2p) { + return VERSION_SD1_PIX2PIX; + } return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { + if (is_inpaint) { + return VERSION_SD2_INPAINT; + } return VERSION_SD2; } return VERSION_COUNT; @@ -1460,7 +1756,7 @@ ggml_type ModelLoader::get_diffusion_model_wtype() { continue; } - if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos) { + if (tensor_storage.name.find("model.diffusion_model.") == std::string::npos && tensor_storage.name.find("unet.") == std::string::npos) { continue; } @@ -1497,6 +1793,30 @@ ggml_type ModelLoader::get_vae_wtype() { return GGML_TYPE_COUNT; } +void ModelLoader::set_wtype_override(ggml_type wtype, std::string prefix) { + for (auto& pair : tensor_storages_types) { + if (prefix.size() < 1 || pair.first.substr(0, prefix.size()) == prefix) { + bool found = false; + for (auto& tensor_storage : tensor_storages) { + std::map temp; + add_preprocess_tensor_storage_types(temp, tensor_storage.name, tensor_storage.type); + for (auto& preprocessed_name : temp) { + if (preprocessed_name.first == pair.first) { + if (tensor_should_be_converted(tensor_storage, wtype)) { + pair.second = wtype; + } + found = true; + break; + } + } + if (found) { + break; + } + } + } + } +} + std::string ModelLoader::load_merges() { std::string merges_utf8_str(reinterpret_cast(merges_utf8_c_str), sizeof(merges_utf8_c_str)); return merges_utf8_str; @@ -1598,9 +1918,12 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } return true; }; - + int tensor_count = 0; + int64_t t1 = ggml_time_ms(); + bool partial = false; for (auto& tensor_storage : processed_tensor_storages) { if (tensor_storage.file_index != file_index) { + ++tensor_count; continue; } ggml_tensor* dst_tensor = NULL; @@ -1612,6 +1935,7 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } if (dst_tensor == NULL) { + ++tensor_count; continue; } @@ -1629,6 +1953,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements()); } } else { read_buffer.resize(tensor_storage.nbytes()); @@ -1640,6 +1967,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data, @@ -1655,6 +1985,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend } else if (tensor_storage.is_f8_e4m3) { // inplace op f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); + } else if (tensor_storage.is_f8_e5m2) { + // inplace op + f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements()); } if (tensor_storage.type == dst_tensor->type) { @@ -1669,12 +2002,21 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend ggml_backend_tensor_set(dst_tensor, convert_buffer.data(), 0, ggml_nbytes(dst_tensor)); } } + size_t tensor_max = processed_tensor_storages.size(); + int64_t t2 = ggml_time_ms(); + pretty_progress(++tensor_count, tensor_max, (t2 - t1) / 1000.0f); + t1 = t2; + partial = tensor_count != tensor_max; } if (zip != NULL) { zip_close(zip); } + if (partial) { + printf("\n"); + } + if (!success) { break; } @@ -1735,9 +2077,6 @@ bool ModelLoader::load_tensors(std::map& tenso if (pair.first.find("cond_stage_model.transformer.text_model.encoder.layers.23") != std::string::npos) { continue; } - if (pair.first.find("alphas_cumprod") != std::string::npos) { - continue; - } if (pair.first.find("alphas_cumprod") != std::string::npos) { continue; @@ -1755,6 +2094,41 @@ bool ModelLoader::load_tensors(std::map& tenso return true; } +std::vector> parse_tensor_type_rules(const std::string& tensor_type_rules) { + std::vector> result; + for (const auto& item : splitString(tensor_type_rules, ',')) { + if (item.size() == 0) + continue; + std::string::size_type pos = item.find('='); + if (pos == std::string::npos) { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + continue; + } + std::string tensor_pattern = item.substr(0, pos); + std::string type_name = item.substr(pos + 1); + + ggml_type tensor_type = GGML_TYPE_COUNT; + + if (type_name == "f32") { + tensor_type = GGML_TYPE_F32; + } else { + for (size_t i = 0; i < SD_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + if (trait->to_float && trait->type_size && type_name == trait->type_name) { + tensor_type = (ggml_type)i; + } + } + } + + if (tensor_type != GGML_TYPE_COUNT) { + result.emplace_back(tensor_pattern, tensor_type); + } else { + LOG_WARN("ignoring invalid quant override \"%s\"", item.c_str()); + } + } + return result; +} + bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type) { const std::string& name = tensor_storage.name; if (type != GGML_TYPE_COUNT) { @@ -1786,7 +2160,7 @@ bool ModelLoader::tensor_should_be_converted(const TensorStorage& tensor_storage return false; } -bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type) { +bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules_str) { auto backend = ggml_backend_cpu_init(); size_t mem_size = 1 * 1024 * 1024; // for padding mem_size += tensor_storages.size() * ggml_tensor_overhead(); @@ -1796,12 +2170,23 @@ bool ModelLoader::save_to_gguf_file(const std::string& file_path, ggml_type type gguf_context* gguf_ctx = gguf_init_empty(); + auto tensor_type_rules = parse_tensor_type_rules(tensor_type_rules_str); + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { const std::string& name = tensor_storage.name; + ggml_type tensor_type = tensor_storage.type; + ggml_type dst_type = type; - ggml_type tensor_type = tensor_storage.type; - if (tensor_should_be_converted(tensor_storage, type)) { - tensor_type = type; + for (const auto& tensor_type_rule : tensor_type_rules) { + std::regex pattern(tensor_type_rule.first); + if (std::regex_search(name, pattern)) { + dst_type = tensor_type_rule.second; + break; + } + } + + if (tensor_should_be_converted(tensor_storage, dst_type)) { + tensor_type = dst_type; } ggml_tensor* tensor = ggml_new_tensor(ggml_ctx, tensor_type, tensor_storage.n_dims, tensor_storage.ne); @@ -1860,7 +2245,7 @@ int64_t ModelLoader::get_params_mem_size(ggml_backend_t backend, ggml_type type) return mem_size; } -bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type) { +bool convert(const char* input_path, const char* vae_path, const char* output_path, sd_type_t output_type, const char* tensor_type_rules) { ModelLoader model_loader; if (!model_loader.init_from_file(input_path)) { @@ -1874,6 +2259,6 @@ bool convert(const char* input_path, const char* vae_path, const char* output_pa return false; } } - bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type); + bool success = model_loader.save_to_gguf_file(output_path, (ggml_type)output_type, tensor_type_rules); return success; } diff --git a/model.h b/model.h index 4efbdf813..a6266039a 100644 --- a/model.h +++ b/model.h @@ -12,6 +12,7 @@ #include "ggml-backend.h" #include "ggml.h" +#include "gguf.h" #include "json.hpp" #include "zip.h" @@ -19,21 +20,88 @@ enum SDVersion { VERSION_SD1, + VERSION_SD1_INPAINT, + VERSION_SD1_PIX2PIX, VERSION_SD2, + VERSION_SD2_INPAINT, VERSION_SDXL, + VERSION_SDXL_INPAINT, + VERSION_SDXL_PIX2PIX, VERSION_SVD, - VERSION_SD3_2B, - VERSION_FLUX_DEV, - VERSION_FLUX_SCHNELL, - VERSION_SD3_5_8B, + VERSION_SD3, + VERSION_FLUX, + VERSION_FLUX_FILL, VERSION_COUNT, }; +static inline bool sd_version_is_flux(SDVersion version) { + if (version == VERSION_FLUX || version == VERSION_FLUX_FILL) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd3(SDVersion version) { + if (version == VERSION_SD3) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd1(SDVersion version) { + if (version == VERSION_SD1 || version == VERSION_SD1_INPAINT || version == VERSION_SD1_PIX2PIX) { + return true; + } + return false; +} + +static inline bool sd_version_is_sd2(SDVersion version) { + if (version == VERSION_SD2 || version == VERSION_SD2_INPAINT) { + return true; + } + return false; +} + +static inline bool sd_version_is_sdxl(SDVersion version) { + if (version == VERSION_SDXL || version == VERSION_SDXL_INPAINT || version == VERSION_SDXL_PIX2PIX) { + return true; + } + return false; +} + +static inline bool sd_version_is_inpaint(SDVersion version) { + if (version == VERSION_SD1_INPAINT || version == VERSION_SD2_INPAINT || version == VERSION_SDXL_INPAINT || version == VERSION_FLUX_FILL) { + return true; + } + return false; +} + +static inline bool sd_version_is_dit(SDVersion version) { + if (sd_version_is_flux(version) || sd_version_is_sd3(version)) { + return true; + } + return false; +} + +static inline bool sd_version_is_unet_edit(SDVersion version) { + return version == VERSION_SD1_PIX2PIX || version == VERSION_SDXL_PIX2PIX; +} + +static bool sd_version_is_inpaint_or_unet_edit(SDVersion version) { + return sd_version_is_unet_edit(version) || sd_version_is_inpaint(version); +} + +enum PMVersion { + PM_VERSION_1, + PM_VERSION_2, +}; + struct TensorStorage { std::string name; ggml_type type = GGML_TYPE_F32; bool is_bf16 = false; bool is_f8_e4m3 = false; + bool is_f8_e5m2 = false; int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1}; int n_dims = 0; @@ -63,7 +131,7 @@ struct TensorStorage { } int64_t nbytes_to_read() const { - if (is_bf16 || is_f8_e4m3) { + if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) { return nbytes() / 2; } else { return nbytes(); @@ -113,6 +181,8 @@ struct TensorStorage { type_name = "bf16"; } else if (is_f8_e4m3) { type_name = "f8_e4m3"; + } else if (is_f8_e5m2) { + type_name = "f8_e5m2"; } ss << name << " | " << type_name << " | "; ss << n_dims << " ["; @@ -139,7 +209,7 @@ class ModelLoader { zip_t* zip, std::string dir, size_t file_index, - const std::string& prefix); + const std::string prefix); bool init_from_gguf_file(const std::string& file_path, const std::string& prefix = ""); bool init_from_safetensors_file(const std::string& file_path, const std::string& prefix = ""); @@ -147,17 +217,22 @@ class ModelLoader { bool init_from_diffusers_file(const std::string& file_path, const std::string& prefix = ""); public: + std::map tensor_storages_types; + bool init_from_file(const std::string& file_path, const std::string& prefix = ""); + bool model_is_unet(); SDVersion get_sd_version(); ggml_type get_sd_wtype(); ggml_type get_conditioner_wtype(); ggml_type get_diffusion_model_wtype(); ggml_type get_vae_wtype(); + void set_wtype_override(ggml_type wtype, std::string prefix = ""); bool load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend_t backend); bool load_tensors(std::map& tensors, ggml_backend_t backend, std::set ignore_tensors = {}); - bool save_to_gguf_file(const std::string& file_path, ggml_type type); + + bool save_to_gguf_file(const std::string& file_path, ggml_type type, const std::string& tensor_type_rules); bool tensor_should_be_converted(const TensorStorage& tensor_storage, ggml_type type); int64_t get_params_mem_size(ggml_backend_t backend, ggml_type type = GGML_TYPE_COUNT); ~ModelLoader() = default; diff --git a/pmid.hpp b/pmid.hpp index 381050fef..ea9f02eb6 100644 --- a/pmid.hpp +++ b/pmid.hpp @@ -42,6 +42,370 @@ struct FuseBlock : public GGMLBlock { } }; +/* +class QFormerPerceiver(nn.Module): + def __init__(self, id_embeddings_dim, cross_attention_dim, num_tokens, embedding_dim=1024, use_residual=True, ratio=4): + super().__init__() + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.use_residual = use_residual + print(cross_attention_dim*num_tokens) + self.token_proj = nn.Sequential( + nn.Linear(id_embeddings_dim, id_embeddings_dim*ratio), + nn.GELU(), + nn.Linear(id_embeddings_dim*ratio, cross_attention_dim*num_tokens), + ) + self.token_norm = nn.LayerNorm(cross_attention_dim) + self.perceiver_resampler = FacePerceiverResampler( + dim=cross_attention_dim, + depth=4, + dim_head=128, + heads=cross_attention_dim // 128, + embedding_dim=embedding_dim, + output_dim=cross_attention_dim, + ff_mult=4, + ) + + def forward(self, x, last_hidden_state): + x = self.token_proj(x) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.token_norm(x) # cls token + out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens + if self.use_residual: # TODO: if use_residual is not true + out = x + 1.0 * out + return out +*/ + +struct PMFeedForward : public GGMLBlock { + // network hparams + int dim; + +public: + PMFeedForward(int d, int multi = 4) + : dim(d) { + int inner_dim = dim * multi; + blocks["0"] = std::shared_ptr(new LayerNorm(dim)); + blocks["1"] = std::shared_ptr(new Mlp(dim, inner_dim, dim, false)); + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x) { + auto norm = std::dynamic_pointer_cast(blocks["0"]); + auto ff = std::dynamic_pointer_cast(blocks["1"]); + + x = norm->forward(ctx, x); + x = ff->forward(ctx, x); + return x; + } +}; + +struct PerceiverAttention : public GGMLBlock { + // network hparams + float scale; // = dim_head**-0.5 + int dim_head; // = dim_head + int heads; // = heads +public: + PerceiverAttention(int dim, int dim_h = 64, int h = 8) + : scale(powf(dim_h, -0.5)), dim_head(dim_h), heads(h) { + int inner_dim = dim_head * heads; + blocks["norm1"] = std::shared_ptr(new LayerNorm(dim)); + blocks["norm2"] = std::shared_ptr(new LayerNorm(dim)); + blocks["to_q"] = std::shared_ptr(new Linear(dim, inner_dim, false)); + blocks["to_kv"] = std::shared_ptr(new Linear(dim, inner_dim * 2, false)); + blocks["to_out"] = std::shared_ptr(new Linear(inner_dim, dim, false)); + } + + struct ggml_tensor* reshape_tensor(struct ggml_context* ctx, + struct ggml_tensor* x, + int heads) { + int64_t ne[4]; + for (int i = 0; i < 4; ++i) + ne[i] = x->ne[i]; + // print_ggml_tensor(x, true, "PerceiverAttention reshape x 0: "); + // printf("heads = %d \n", heads); + // x = ggml_view_4d(ctx, x, x->ne[0], x->ne[1], heads, x->ne[2]/heads, + // x->nb[1], x->nb[2], x->nb[3], 0); + x = ggml_reshape_4d(ctx, x, x->ne[0] / heads, heads, x->ne[1], x->ne[2]); + // x = ggml_view_4d(ctx, x, x->ne[0]/heads, heads, x->ne[1], x->ne[2], + // x->nb[1], x->nb[2], x->nb[3], 0); + // x = ggml_cont(ctx, x); + x = ggml_cont(ctx, ggml_permute(ctx, x, 0, 2, 1, 3)); + // print_ggml_tensor(x, true, "PerceiverAttention reshape x 1: "); + // x = ggml_reshape_4d(ctx, x, ne[0], heads, ne[1], ne[2]/heads); + return x; + } + + std::vector chunk_half(struct ggml_context* ctx, + struct ggml_tensor* x) { + auto tlo = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], 0); + auto tli = ggml_view_4d(ctx, x, x->ne[0] / 2, x->ne[1], x->ne[2], x->ne[3], x->nb[1], x->nb[2], x->nb[3], x->nb[0] * x->ne[0] / 2); + return {ggml_cont(ctx, tlo), + ggml_cont(ctx, tli)}; + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* latents) { + // x (torch.Tensor): image features + // shape (b, n1, D) + // latent (torch.Tensor): latent features + // shape (b, n2, D) + int64_t ne[4]; + for (int i = 0; i < 4; ++i) + ne[i] = latents->ne[i]; + + auto norm1 = std::dynamic_pointer_cast(blocks["norm1"]); + auto norm2 = std::dynamic_pointer_cast(blocks["norm2"]); + x = norm1->forward(ctx, x); + latents = norm2->forward(ctx, latents); + auto to_q = std::dynamic_pointer_cast(blocks["to_q"]); + auto q = to_q->forward(ctx, latents); + + auto kv_input = ggml_concat(ctx, x, latents, 1); + auto to_kv = std::dynamic_pointer_cast(blocks["to_kv"]); + auto kv = to_kv->forward(ctx, kv_input); + auto k = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, 0); + auto v = ggml_view_4d(ctx, kv, kv->ne[0] / 2, kv->ne[1], kv->ne[2], kv->ne[3], kv->nb[1] / 2, kv->nb[2] / 2, kv->nb[3] / 2, kv->nb[0] * (kv->ne[0] / 2)); + k = ggml_cont(ctx, k); + v = ggml_cont(ctx, v); + q = reshape_tensor(ctx, q, heads); + k = reshape_tensor(ctx, k, heads); + v = reshape_tensor(ctx, v, heads); + scale = 1.f / sqrt(sqrt((float)dim_head)); + k = ggml_scale_inplace(ctx, k, scale); + q = ggml_scale_inplace(ctx, q, scale); + // auto weight = ggml_mul_mat(ctx, q, k); + auto weight = ggml_mul_mat(ctx, k, q); // NOTE order of mul is opposite to pytorch + + // GGML's softmax() is equivalent to pytorch's softmax(x, dim=-1) + // in this case, dimension along which Softmax will be computed is the last dim + // in torch and the first dim in GGML, consistent with the convention that pytorch's + // last dimension (varying most rapidly) corresponds to GGML's first (varying most rapidly). + // weight = ggml_soft_max(ctx, weight); + weight = ggml_soft_max_inplace(ctx, weight); + v = ggml_cont(ctx, ggml_transpose(ctx, v)); + // auto out = ggml_mul_mat(ctx, weight, v); + auto out = ggml_mul_mat(ctx, v, weight); // NOTE order of mul is opposite to pytorch + out = ggml_cont(ctx, ggml_permute(ctx, out, 0, 2, 1, 3)); + out = ggml_reshape_3d(ctx, out, ne[0], ne[1], ggml_nelements(out) / (ne[0] * ne[1])); + auto to_out = std::dynamic_pointer_cast(blocks["to_out"]); + out = to_out->forward(ctx, out); + return out; + } +}; + +struct FacePerceiverResampler : public GGMLBlock { + // network hparams + int depth; + +public: + FacePerceiverResampler(int dim = 768, + int d = 4, + int dim_head = 64, + int heads = 16, + int embedding_dim = 1280, + int output_dim = 768, + int ff_mult = 4) + : depth(d) { + blocks["proj_in"] = std::shared_ptr(new Linear(embedding_dim, dim, true)); + blocks["proj_out"] = std::shared_ptr(new Linear(dim, output_dim, true)); + blocks["norm_out"] = std::shared_ptr(new LayerNorm(output_dim)); + + for (int i = 0; i < depth; i++) { + std::string name = "layers." + std::to_string(i) + ".0"; + blocks[name] = std::shared_ptr(new PerceiverAttention(dim, dim_head, heads)); + name = "layers." + std::to_string(i) + ".1"; + blocks[name] = std::shared_ptr(new PMFeedForward(dim, ff_mult)); + } + } + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* latents, + struct ggml_tensor* x) { + // x: [N, channels, h, w] + auto proj_in = std::dynamic_pointer_cast(blocks["proj_in"]); + auto proj_out = std::dynamic_pointer_cast(blocks["proj_out"]); + auto norm_out = std::dynamic_pointer_cast(blocks["norm_out"]); + + x = proj_in->forward(ctx, x); + for (int i = 0; i < depth; i++) { + std::string name = "layers." + std::to_string(i) + ".0"; + auto attn = std::dynamic_pointer_cast(blocks[name]); + name = "layers." + std::to_string(i) + ".1"; + auto ff = std::dynamic_pointer_cast(blocks[name]); + auto t = attn->forward(ctx, x, latents); + latents = ggml_add(ctx, t, latents); + t = ff->forward(ctx, latents); + latents = ggml_add(ctx, t, latents); + } + latents = proj_out->forward(ctx, latents); + latents = norm_out->forward(ctx, latents); + return latents; + } +}; + +struct QFormerPerceiver : public GGMLBlock { + // network hparams + int num_tokens; + int cross_attention_dim; + bool use_residul; + +public: + QFormerPerceiver(int id_embeddings_dim, int cross_attention_d, int num_t, int embedding_dim = 1024, bool use_r = true, int ratio = 4) + : cross_attention_dim(cross_attention_d), num_tokens(num_t), use_residul(use_r) { + blocks["token_proj"] = std::shared_ptr(new Mlp(id_embeddings_dim, + id_embeddings_dim * ratio, + cross_attention_dim * num_tokens, + true)); + blocks["token_norm"] = std::shared_ptr(new LayerNorm(cross_attention_d)); + blocks["perceiver_resampler"] = std::shared_ptr(new FacePerceiverResampler( + cross_attention_dim, + 4, + 128, + cross_attention_dim / 128, + embedding_dim, + cross_attention_dim, + 4)); + } + + /* + def forward(self, x, last_hidden_state): + x = self.token_proj(x) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + x = self.token_norm(x) # cls token + out = self.perceiver_resampler(x, last_hidden_state) # retrieve from patch tokens + if self.use_residual: # TODO: if use_residual is not true + out = x + 1.0 * out + return out + */ + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* x, + struct ggml_tensor* last_hidden_state) { + // x: [N, channels, h, w] + auto token_proj = std::dynamic_pointer_cast(blocks["token_proj"]); + auto token_norm = std::dynamic_pointer_cast(blocks["token_norm"]); + auto perceiver_resampler = std::dynamic_pointer_cast(blocks["perceiver_resampler"]); + + x = token_proj->forward(ctx, x); + int64_t nel = ggml_nelements(x); + x = ggml_reshape_3d(ctx, x, cross_attention_dim, num_tokens, nel / (cross_attention_dim * num_tokens)); + x = token_norm->forward(ctx, x); + struct ggml_tensor* out = perceiver_resampler->forward(ctx, x, last_hidden_state); + if (use_residul) + out = ggml_add(ctx, x, out); + return out; + } +}; + +/* +class FacePerceiverResampler(torch.nn.Module): + def __init__( + self, + *, + dim=768, + depth=4, + dim_head=64, + heads=16, + embedding_dim=1280, + output_dim=768, + ff_mult=4, + ): + super().__init__() + + self.proj_in = torch.nn.Linear(embedding_dim, dim) + self.proj_out = torch.nn.Linear(dim, output_dim) + self.norm_out = torch.nn.LayerNorm(output_dim) + self.layers = torch.nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + torch.nn.ModuleList( + [ + PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads), + FeedForward(dim=dim, mult=ff_mult), + ] + ) + ) + + def forward(self, latents, x): + x = self.proj_in(x) + for attn, ff in self.layers: + latents = attn(x, latents) + latents + latents = ff(latents) + latents + latents = self.proj_out(latents) + return self.norm_out(latents) +*/ + +/* + +def FeedForward(dim, mult=4): + inner_dim = int(dim * mult) + return nn.Sequential( + nn.LayerNorm(dim), + nn.Linear(dim, inner_dim, bias=False), + nn.GELU(), + nn.Linear(inner_dim, dim, bias=False), + ) + +def reshape_tensor(x, heads): + bs, length, width = x.shape + # (bs, length, width) --> (bs, length, n_heads, dim_per_head) + x = x.view(bs, length, heads, -1) + # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) + x = x.transpose(1, 2) + # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) + x = x.reshape(bs, heads, length, -1) + return x + +class PerceiverAttention(nn.Module): + def __init__(self, *, dim, dim_head=64, heads=8): + super().__init__() + self.scale = dim_head**-0.5 + self.dim_head = dim_head + self.heads = heads + inner_dim = dim_head * heads + + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + + self.to_q = nn.Linear(dim, inner_dim, bias=False) + self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) + self.to_out = nn.Linear(inner_dim, dim, bias=False) + + def forward(self, x, latents): + """ + Args: + x (torch.Tensor): image features + shape (b, n1, D) + latent (torch.Tensor): latent features + shape (b, n2, D) + """ + x = self.norm1(x) + latents = self.norm2(latents) + + b, l, _ = latents.shape + + q = self.to_q(latents) + kv_input = torch.cat((x, latents), dim=-2) + k, v = self.to_kv(kv_input).chunk(2, dim=-1) + + q = reshape_tensor(q, self.heads) + k = reshape_tensor(k, self.heads) + v = reshape_tensor(v, self.heads) + + # attention + scale = 1 / math.sqrt(math.sqrt(self.dim_head)) + weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards + weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) + out = weight @ v + + out = out.permute(0, 2, 1, 3).reshape(b, l, -1) + + return self.to_out(out) + +*/ + struct FuseModule : public GGMLBlock { // network hparams int embed_dim; @@ -61,12 +425,19 @@ struct FuseModule : public GGMLBlock { auto mlp2 = std::dynamic_pointer_cast(blocks["mlp2"]); auto layer_norm = std::dynamic_pointer_cast(blocks["layer_norm"]); - auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); - auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); - // concat is along dim 2 - auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); + // print_ggml_tensor(id_embeds, true, "Fuseblock id_embeds: "); + // print_ggml_tensor(prompt_embeds, true, "Fuseblock prompt_embeds: "); + // auto prompt_embeds0 = ggml_cont(ctx, ggml_permute(ctx, prompt_embeds, 2, 0, 1, 3)); + // auto id_embeds0 = ggml_cont(ctx, ggml_permute(ctx, id_embeds, 2, 0, 1, 3)); + // print_ggml_tensor(id_embeds0, true, "Fuseblock id_embeds0: "); + // print_ggml_tensor(prompt_embeds0, true, "Fuseblock prompt_embeds0: "); + // concat is along dim 2 + // auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds0, id_embeds0, 2); + auto stacked_id_embeds = ggml_concat(ctx, prompt_embeds, id_embeds, 0); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 0: "); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 1, 2, 0, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); // stacked_id_embeds = mlp1.forward(ctx, stacked_id_embeds); // stacked_id_embeds = ggml_add(ctx, stacked_id_embeds, prompt_embeds); // stacked_id_embeds = mlp2.forward(ctx, stacked_id_embeds); @@ -77,6 +448,8 @@ struct FuseModule : public GGMLBlock { stacked_id_embeds = mlp2->forward(ctx, stacked_id_embeds); stacked_id_embeds = layer_norm->forward(ctx, stacked_id_embeds); + // print_ggml_tensor(stacked_id_embeds, true, "Fuseblock stacked_id_embeds 1: "); + return stacked_id_embeds; } @@ -98,23 +471,31 @@ struct FuseModule : public GGMLBlock { // print_ggml_tensor(class_tokens_mask_pos, true, "class_tokens_mask_pos"); struct ggml_tensor* image_token_embeds = ggml_get_rows(ctx, prompt_embeds, class_tokens_mask_pos); ggml_set_name(image_token_embeds, "image_token_embeds"); + valid_id_embeds = ggml_reshape_2d(ctx, valid_id_embeds, valid_id_embeds->ne[0], + ggml_nelements(valid_id_embeds) / valid_id_embeds->ne[0]); struct ggml_tensor* stacked_id_embeds = fuse_fn(ctx, image_token_embeds, valid_id_embeds); - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "AA stacked_id_embeds"); + // print_ggml_tensor(left, true, "AA left"); + // print_ggml_tensor(right, true, "AA right"); if (left && right) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2); - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2); + stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); + stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } else if (left) { - stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 2); + stacked_id_embeds = ggml_concat(ctx, left, stacked_id_embeds, 1); } else if (right) { - stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 2); + stacked_id_embeds = ggml_concat(ctx, stacked_id_embeds, right, 1); } - stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "BB stacked_id_embeds"); + // stacked_id_embeds = ggml_cont(ctx, ggml_permute(ctx, stacked_id_embeds, 0, 2, 1, 3)); + // print_ggml_tensor(stacked_id_embeds, true, "CC stacked_id_embeds"); class_tokens_mask = ggml_cont(ctx, ggml_transpose(ctx, class_tokens_mask)); class_tokens_mask = ggml_repeat(ctx, class_tokens_mask, prompt_embeds); prompt_embeds = ggml_mul(ctx, prompt_embeds, class_tokens_mask); struct ggml_tensor* updated_prompt_embeds = ggml_add(ctx, prompt_embeds, stacked_id_embeds); ggml_set_name(updated_prompt_embeds, "updated_prompt_embeds"); + // print_ggml_tensor(updated_prompt_embeds, true, "updated_prompt_embeds: "); return updated_prompt_embeds; } }; @@ -159,10 +540,77 @@ struct PhotoMakerIDEncoderBlock : public CLIPVisionModelProjection { } }; +struct PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock : public CLIPVisionModelProjection { + int cross_attention_dim; + int num_tokens; + + PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock(int id_embeddings_dim = 512) + : CLIPVisionModelProjection(OPENAI_CLIP_VIT_L_14), + cross_attention_dim(2048), + num_tokens(2) { + blocks["visual_projection_2"] = std::shared_ptr(new Linear(1024, 1280, false)); + blocks["fuse_module"] = std::shared_ptr(new FuseModule(2048)); + /* + cross_attention_dim = 2048 + # projection + self.num_tokens = 2 + self.cross_attention_dim = cross_attention_dim + self.qformer_perceiver = QFormerPerceiver( + id_embeddings_dim, + cross_attention_dim, + self.num_tokens, + )*/ + blocks["qformer_perceiver"] = std::shared_ptr(new QFormerPerceiver(id_embeddings_dim, + cross_attention_dim, + num_tokens)); + } + + /* + def forward(self, id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds): + b, num_inputs, c, h, w = id_pixel_values.shape + id_pixel_values = id_pixel_values.view(b * num_inputs, c, h, w) + + last_hidden_state = self.vision_model(id_pixel_values)[0] + id_embeds = id_embeds.view(b * num_inputs, -1) + + id_embeds = self.qformer_perceiver(id_embeds, last_hidden_state) + id_embeds = id_embeds.view(b, num_inputs, self.num_tokens, -1) + updated_prompt_embeds = self.fuse_module(prompt_embeds, id_embeds, class_tokens_mask) + */ + + struct ggml_tensor* forward(struct ggml_context* ctx, + struct ggml_tensor* id_pixel_values, + struct ggml_tensor* prompt_embeds, + struct ggml_tensor* class_tokens_mask, + struct ggml_tensor* class_tokens_mask_pos, + struct ggml_tensor* id_embeds, + struct ggml_tensor* left, + struct ggml_tensor* right) { + // x: [N, channels, h, w] + auto vision_model = std::dynamic_pointer_cast(blocks["vision_model"]); + auto fuse_module = std::dynamic_pointer_cast(blocks["fuse_module"]); + auto qformer_perceiver = std::dynamic_pointer_cast(blocks["qformer_perceiver"]); + + // struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values); // [N, hidden_size] + struct ggml_tensor* last_hidden_state = vision_model->forward(ctx, id_pixel_values, false); // [N, hidden_size] + id_embeds = qformer_perceiver->forward(ctx, id_embeds, last_hidden_state); + + struct ggml_tensor* updated_prompt_embeds = fuse_module->forward(ctx, + prompt_embeds, + id_embeds, + class_tokens_mask, + class_tokens_mask_pos, + left, right); + return updated_prompt_embeds; + } +}; + struct PhotoMakerIDEncoder : public GGMLRunner { public: - SDVersion version = VERSION_SDXL; + SDVersion version = VERSION_SDXL; + PMVersion pm_version = PM_VERSION_1; PhotoMakerIDEncoderBlock id_encoder; + PhotoMakerIDEncoder_CLIPInsightfaceExtendtokenBlock id_encoder2; float style_strength; std::vector ctm; @@ -175,25 +623,38 @@ struct PhotoMakerIDEncoder : public GGMLRunner { std::vector zeros_right; public: - PhotoMakerIDEncoder(ggml_backend_t backend, ggml_type wtype, SDVersion version = VERSION_SDXL, float sty = 20.f) - : GGMLRunner(backend, wtype), + PhotoMakerIDEncoder(ggml_backend_t backend, std::map& tensor_types, const std::string prefix, SDVersion version = VERSION_SDXL, PMVersion pm_v = PM_VERSION_1, float sty = 20.f) + : GGMLRunner(backend), version(version), + pm_version(pm_v), style_strength(sty) { - id_encoder.init(params_ctx, wtype); + if (pm_version == PM_VERSION_1) { + id_encoder.init(params_ctx, tensor_types, prefix); + } else if (pm_version == PM_VERSION_2) { + id_encoder2.init(params_ctx, tensor_types, prefix); + } } std::string get_desc() { return "pmid"; } + PMVersion get_version() const { + return pm_version; + } + void get_param_tensors(std::map& tensors, const std::string prefix) { - id_encoder.get_param_tensors(tensors, prefix); + if (pm_version == PM_VERSION_1) + id_encoder.get_param_tensors(tensors, prefix); + else if (pm_version == PM_VERSION_2) + id_encoder2.get_param_tensors(tensors, prefix); } struct ggml_cgraph* build_graph( // struct ggml_allocr* allocr, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, - std::vector& class_tokens_mask) { + std::vector& class_tokens_mask, + struct ggml_tensor* id_embeds) { ctm.clear(); ctmf16.clear(); ctmpos.clear(); @@ -214,25 +675,32 @@ struct PhotoMakerIDEncoder : public GGMLRunner { struct ggml_tensor* id_pixel_values_d = to_backend(id_pixel_values); struct ggml_tensor* prompt_embeds_d = to_backend(prompt_embeds); + struct ggml_tensor* id_embeds_d = to_backend(id_embeds); struct ggml_tensor* left = NULL; struct ggml_tensor* right = NULL; for (int i = 0; i < class_tokens_mask.size(); i++) { if (class_tokens_mask[i]) { + // printf(" 1,"); ctm.push_back(0.f); // here use 0.f instead of 1.f to make a scale mask ctmf16.push_back(ggml_fp32_to_fp16(0.f)); // here use 0.f instead of 1.f to make a scale mask ctmpos.push_back(i); } else { + // printf(" 0,"); ctm.push_back(1.f); // here use 1.f instead of 0.f to make a scale mask ctmf16.push_back(ggml_fp32_to_fp16(1.f)); // here use 0.f instead of 1.f to make a scale mask } } + // printf("\n"); if (ctmpos[0] > 0) { - left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]); + // left = ggml_new_tensor_3d(ctx0, type, hidden_size, 1, ctmpos[0]); + left = ggml_new_tensor_3d(ctx0, type, hidden_size, ctmpos[0], 1); } if (ctmpos[ctmpos.size() - 1] < seq_length - 1) { + // right = ggml_new_tensor_3d(ctx0, type, + // hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1); right = ggml_new_tensor_3d(ctx0, type, - hidden_size, 1, seq_length - ctmpos[ctmpos.size() - 1] - 1); + hidden_size, seq_length - ctmpos[ctmpos.size() - 1] - 1, 1); } struct ggml_tensor* class_tokens_mask_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ctmpos.size()); @@ -265,12 +733,23 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } } } - struct ggml_tensor* updated_prompt_embeds = id_encoder.forward(ctx0, - id_pixel_values_d, - prompt_embeds_d, - class_tokens_mask_d, - class_tokens_mask_pos, - left, right); + struct ggml_tensor* updated_prompt_embeds = NULL; + if (pm_version == PM_VERSION_1) + updated_prompt_embeds = id_encoder.forward(ctx0, + id_pixel_values_d, + prompt_embeds_d, + class_tokens_mask_d, + class_tokens_mask_pos, + left, right); + else if (pm_version == PM_VERSION_2) + updated_prompt_embeds = id_encoder2.forward(ctx0, + id_pixel_values_d, + prompt_embeds_d, + class_tokens_mask_d, + class_tokens_mask_pos, + id_embeds_d, + left, right); + ggml_build_forward_expand(gf, updated_prompt_embeds); return gf; @@ -279,12 +758,13 @@ struct PhotoMakerIDEncoder : public GGMLRunner { void compute(const int n_threads, struct ggml_tensor* id_pixel_values, struct ggml_tensor* prompt_embeds, + struct ggml_tensor* id_embeds, std::vector& class_tokens_mask, struct ggml_tensor** updated_prompt_embeds, ggml_context* output_ctx) { auto get_graph = [&]() -> struct ggml_cgraph* { // return build_graph(compute_allocr, id_pixel_values, prompt_embeds, class_tokens_mask); - return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask); + return build_graph(id_pixel_values, prompt_embeds, class_tokens_mask, id_embeds); }; // GGMLRunner::compute(get_graph, n_threads, updated_prompt_embeds); @@ -292,4 +772,74 @@ struct PhotoMakerIDEncoder : public GGMLRunner { } }; +struct PhotoMakerIDEmbed : public GGMLRunner { + std::map tensors; + std::string file_path; + ModelLoader* model_loader; + bool load_failed = false; + bool applied = false; + + PhotoMakerIDEmbed(ggml_backend_t backend, + ModelLoader* ml, + const std::string& file_path = "", + const std::string& prefix = "") + : file_path(file_path), GGMLRunner(backend), model_loader(ml) { + if (!model_loader->init_from_file(file_path, prefix)) { + load_failed = true; + } + } + + std::string get_desc() { + return "id_embeds"; + } + + bool load_from_file(bool filter_tensor = false) { + LOG_INFO("loading PhotoMaker ID Embeds from '%s'", file_path.c_str()); + + if (load_failed) { + LOG_ERROR("init photomaker id embed from file failed: '%s'", file_path.c_str()); + return false; + } + + bool dry_run = true; + auto on_new_tensor_cb = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) -> bool { + const std::string& name = tensor_storage.name; + + if (filter_tensor && !contains(name, "pmid.id_embeds")) { + // LOG_INFO("skipping LoRA tesnor '%s'", name.c_str()); + return true; + } + if (dry_run) { + struct ggml_tensor* real = ggml_new_tensor(params_ctx, + tensor_storage.type, + tensor_storage.n_dims, + tensor_storage.ne); + tensors[name] = real; + } else { + auto real = tensors[name]; + *dst_tensor = real; + } + + return true; + }; + + model_loader->load_tensors(on_new_tensor_cb, backend); + alloc_params_buffer(); + + dry_run = false; + model_loader->load_tensors(on_new_tensor_cb, backend); + + LOG_DEBUG("finished loading PhotoMaker ID Embeds "); + return true; + } + + struct ggml_tensor* get() { + std::map::iterator pos; + pos = tensors.find("pmid.id_embeds"); + if (pos != tensors.end()) + return pos->second; + return NULL; + } +}; + #endif // __PMI_HPP__ diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4d28a147b..402585f1c 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -26,13 +26,17 @@ const char* model_version_to_str[] = { "SD 1.x", + "SD 1.x Inpaint", + "Instruct-Pix2Pix", "SD 2.x", + "SD 2.x Inpaint", "SDXL", + "SDXL Inpaint", + "SDXL Instruct-Pix2Pix", "SVD", - "SD3 2B", - "Flux Dev", - "Flux Schnell", - "SD3.5 8B"}; + "SD3.x", + "Flux", + "Flux Fill"}; const char* sampling_methods_str[] = { "Euler A", @@ -45,7 +49,8 @@ const char* sampling_methods_str[] = { "iPNDM", "iPNDM_v", "LCM", -}; + "DDIM \"trailing\"", + "TCD"}; /*================================================== Helper Functions ================================================*/ @@ -93,12 +98,16 @@ class StableDiffusionGGML { std::shared_ptr control_net; std::shared_ptr pmid_model; std::shared_ptr pmid_lora; + std::shared_ptr pmid_id_embeds; std::string taesd_path; bool use_tiny_autoencoder = false; bool vae_tiling = false; bool stacked_id = false; + bool is_using_v_parameterization = false; + bool is_using_edm_v_parameterization = false; + std::map tensors; std::string lora_model_dir; @@ -109,22 +118,6 @@ class StableDiffusionGGML { StableDiffusionGGML() = default; - StableDiffusionGGML(int n_threads, - bool vae_decode_only, - bool free_params_immediately, - std::string lora_model_dir, - rng_type_t rng_type) - : n_threads(n_threads), - vae_decode_only(vae_decode_only), - free_params_immediately(free_params_immediately), - lora_model_dir(lora_model_dir) { - if (rng_type == STD_DEFAULT_RNG) { - rng = std::make_shared(); - } else if (rng_type == CUDA_RNG) { - rng = std::make_shared(); - } - } - ~StableDiffusionGGML() { if (clip_backend != backend) { ggml_backend_free(clip_backend); @@ -138,30 +131,14 @@ class StableDiffusionGGML { ggml_backend_free(backend); } - bool load_from_file(const std::string& model_path, - const std::string& clip_l_path, - const std::string& clip_g_path, - const std::string& t5xxl_path, - const std::string& diffusion_model_path, - const std::string& vae_path, - const std::string control_net_path, - const std::string embeddings_path, - const std::string id_embeddings_path, - const std::string& taesd_path, - bool vae_tiling_, - ggml_type wtype, - schedule_t schedule, - bool clip_on_cpu, - bool control_net_cpu, - bool vae_on_cpu) { - use_tiny_autoencoder = taesd_path.size() > 0; -#ifdef SD_USE_CUBLAS + void init_backend() { +#ifdef SD_USE_CUDA LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + ggml_log_set(ggml_log_callback_default, nullptr); backend = ggml_backend_metal_init(); #endif #ifdef SD_USE_VULKAN @@ -173,6 +150,14 @@ class StableDiffusionGGML { LOG_WARN("Failed to initialize Vulkan backend"); } #endif +#ifdef SD_USE_OPENCL + LOG_DEBUG("Using OpenCL backend"); + // ggml_log_set(ggml_log_callback_default, nullptr); // Optional ggml logs + backend = ggml_backend_opencl_init(); + if (!backend) { + LOG_WARN("Failed to initialize OpenCL backend"); + } +#endif #ifdef SD_USE_SYCL LOG_DEBUG("Using SYCL backend"); backend = ggml_backend_sycl_init(0); @@ -182,66 +167,81 @@ class StableDiffusionGGML { LOG_DEBUG("Using CPU backend"); backend = ggml_backend_cpu_init(); } -#ifdef SD_USE_FLASH_ATTENTION -#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN) - LOG_WARN("Flash Attention not supported with GPU Backend"); -#else - LOG_INFO("Flash Attention enabled"); -#endif -#endif - ModelLoader model_loader; + } + + bool init(const sd_ctx_params_t* sd_ctx_params) { + n_threads = sd_ctx_params->n_threads; + vae_decode_only = sd_ctx_params->vae_decode_only; + free_params_immediately = sd_ctx_params->free_params_immediately; + lora_model_dir = SAFE_STR(sd_ctx_params->lora_model_dir); + taesd_path = SAFE_STR(sd_ctx_params->taesd_path); + use_tiny_autoencoder = taesd_path.size() > 0; + vae_tiling = sd_ctx_params->vae_tiling; - vae_tiling = vae_tiling_; + if (sd_ctx_params->rng_type == STD_DEFAULT_RNG) { + rng = std::make_shared(); + } else if (sd_ctx_params->rng_type == CUDA_RNG) { + rng = std::make_shared(); + } + + init_backend(); - if (model_path.size() > 0) { - LOG_INFO("loading model from '%s'", model_path.c_str()); - if (!model_loader.init_from_file(model_path)) { - LOG_ERROR("init model loader from file failed: '%s'", model_path.c_str()); + ModelLoader model_loader; + + if (strlen(SAFE_STR(sd_ctx_params->model_path)) > 0) { + LOG_INFO("loading model from '%s'", sd_ctx_params->model_path); + if (!model_loader.init_from_file(sd_ctx_params->model_path)) { + LOG_ERROR("init model loader from file failed: '%s'", sd_ctx_params->model_path); } } - if (clip_l_path.size() > 0) { - LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); - if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) { - LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); + if (strlen(SAFE_STR(sd_ctx_params->diffusion_model_path)) > 0) { + LOG_INFO("loading diffusion model from '%s'", sd_ctx_params->diffusion_model_path); + if (!model_loader.init_from_file(sd_ctx_params->diffusion_model_path, "model.diffusion_model.")) { + LOG_WARN("loading diffusion model from '%s' failed", sd_ctx_params->diffusion_model_path); } } - if (clip_g_path.size() > 0) { - LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str()); - if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) { - LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str()); + bool is_unet = model_loader.model_is_unet(); + + if (strlen(SAFE_STR(sd_ctx_params->clip_l_path)) > 0) { + LOG_INFO("loading clip_l from '%s'", sd_ctx_params->clip_l_path); + std::string prefix = is_unet ? "cond_stage_model.transformer." : "text_encoders.clip_l.transformer."; + if (!model_loader.init_from_file(sd_ctx_params->clip_l_path, prefix)) { + LOG_WARN("loading clip_l from '%s' failed", sd_ctx_params->clip_l_path); } } - if (t5xxl_path.size() > 0) { - LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str()); - if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.transformer.")) { - LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str()); + if (strlen(SAFE_STR(sd_ctx_params->clip_g_path)) > 0) { + LOG_INFO("loading clip_g from '%s'", sd_ctx_params->clip_g_path); + std::string prefix = is_unet ? "cond_stage_model.1.transformer." : "text_encoders.clip_g.transformer."; + if (!model_loader.init_from_file(sd_ctx_params->clip_g_path, prefix)) { + LOG_WARN("loading clip_g from '%s' failed", sd_ctx_params->clip_g_path); } } - if (diffusion_model_path.size() > 0) { - LOG_INFO("loading diffusion model from '%s'", diffusion_model_path.c_str()); - if (!model_loader.init_from_file(diffusion_model_path, "model.diffusion_model.")) { - LOG_WARN("loading diffusion model from '%s' failed", diffusion_model_path.c_str()); + if (strlen(SAFE_STR(sd_ctx_params->t5xxl_path)) > 0) { + LOG_INFO("loading t5xxl from '%s'", sd_ctx_params->t5xxl_path); + if (!model_loader.init_from_file(sd_ctx_params->t5xxl_path, "text_encoders.t5xxl.transformer.")) { + LOG_WARN("loading t5xxl from '%s' failed", sd_ctx_params->t5xxl_path); } } - if (vae_path.size() > 0) { - LOG_INFO("loading vae from '%s'", vae_path.c_str()); - if (!model_loader.init_from_file(vae_path, "vae.")) { - LOG_WARN("loading vae from '%s' failed", vae_path.c_str()); + if (strlen(SAFE_STR(sd_ctx_params->vae_path)) > 0) { + LOG_INFO("loading vae from '%s'", sd_ctx_params->vae_path); + if (!model_loader.init_from_file(sd_ctx_params->vae_path, "vae.")) { + LOG_WARN("loading vae from '%s' failed", sd_ctx_params->vae_path); } } version = model_loader.get_sd_version(); if (version == VERSION_COUNT) { - LOG_ERROR("get sd version from file failed: '%s'", model_path.c_str()); + LOG_ERROR("get sd version from file failed: '%s'", SAFE_STR(sd_ctx_params->model_path)); return false; } LOG_INFO("Version: %s ", model_version_to_str[version]); + ggml_type wtype = (ggml_type)sd_ctx_params->wtype; if (wtype == GGML_TYPE_COUNT) { model_wtype = model_loader.get_sd_wtype(); if (model_wtype == GGML_TYPE_COUNT) { @@ -266,52 +266,56 @@ class StableDiffusionGGML { conditioner_wtype = wtype; diffusion_model_wtype = wtype; vae_wtype = wtype; + model_loader.set_wtype_override(wtype); } - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { vae_wtype = GGML_TYPE_F32; + model_loader.set_wtype_override(GGML_TYPE_F32, "vae."); } - LOG_INFO("Weight type: %s", ggml_type_name(model_wtype)); - LOG_INFO("Conditioner weight type: %s", ggml_type_name(conditioner_wtype)); - LOG_INFO("Diffusion model weight type: %s", ggml_type_name(diffusion_model_wtype)); - LOG_INFO("VAE weight type: %s", ggml_type_name(vae_wtype)); + LOG_INFO("Weight type: %s", model_wtype != GGML_TYPE_COUNT ? ggml_type_name(model_wtype) : "??"); + LOG_INFO("Conditioner weight type: %s", conditioner_wtype != GGML_TYPE_COUNT ? ggml_type_name(conditioner_wtype) : "??"); + LOG_INFO("Diffusion model weight type: %s", diffusion_model_wtype != GGML_TYPE_COUNT ? ggml_type_name(diffusion_model_wtype) : "??"); + LOG_INFO("VAE weight type: %s", vae_wtype != GGML_TYPE_COUNT ? ggml_type_name(vae_wtype) : "??"); LOG_DEBUG("ggml tensor size = %d bytes", (int)sizeof(ggml_tensor)); - if (version == VERSION_SDXL) { + if (sd_version_is_sdxl(version)) { scale_factor = 0.13025f; - if (vae_path.size() == 0 && taesd_path.size() == 0) { + if (strlen(SAFE_STR(sd_ctx_params->vae_path)) == 0 && strlen(SAFE_STR(sd_ctx_params->taesd_path)) == 0) { LOG_WARN( "!!!It looks like you are using SDXL model. " "If you find that the generated images are completely black, " "try specifying SDXL VAE FP16 Fix with the --vae parameter. " "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + } else if (sd_version_is_sd3(version)) { scale_factor = 1.5305f; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(version)) { scale_factor = 0.3611; // TODO: shift_factor } + bool clip_on_cpu = sd_ctx_params->keep_clip_on_cpu; + if (version == VERSION_SVD) { - clip_vision = std::make_shared(backend, conditioner_wtype); + clip_vision = std::make_shared(backend, model_loader.tensor_storages_types); clip_vision->alloc_params_buffer(); clip_vision->get_param_tensors(tensors); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types, version); diffusion_model->alloc_params_buffer(); diffusion_model->get_param_tensors(tensors); - first_stage_model = std::make_shared(backend, vae_wtype, vae_decode_only, true, version); + first_stage_model = std::make_shared(backend, model_loader.tensor_storages_types, "first_stage_model", vae_decode_only, true, version); LOG_DEBUG("vae_decode_only %d", vae_decode_only); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_dit(version)) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -322,16 +326,56 @@ class StableDiffusionGGML { LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + if (sd_ctx_params->diffusion_flash_attn) { + LOG_INFO("Using flash attention in the diffusion model"); + } + if (sd_version_is_sd3(version)) { + if (sd_ctx_params->diffusion_flash_attn) { + LOG_WARN("flash attention in this diffusion model is currently unsupported!"); + } + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + diffusion_model = std::make_shared(backend, model_loader.tensor_storages_types); + } else if (sd_version_is_flux(version)) { + bool is_chroma = false; + for (auto pair : model_loader.tensor_storages_types) { + if (pair.first.find("distilled_guidance_layer.in_proj.weight") != std::string::npos) { + is_chroma = true; + break; + } + } + if (is_chroma) { + cond_stage_model = std::make_shared(clip_backend, + model_loader.tensor_storages_types, + -1, + sd_ctx_params->chroma_use_t5_mask, + sd_ctx_params->chroma_t5_mask_pad); + } else { + cond_stage_model = std::make_shared(clip_backend, model_loader.tensor_storages_types); + } + diffusion_model = std::make_shared(backend, + model_loader.tensor_storages_types, + version, + sd_ctx_params->diffusion_flash_attn, + sd_ctx_params->chroma_use_dit_mask); } else { - cond_stage_model = std::make_shared(clip_backend, conditioner_wtype, embeddings_path, version); - diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); + if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) { + cond_stage_model = std::make_shared(clip_backend, + model_loader.tensor_storages_types, + SAFE_STR(sd_ctx_params->embedding_dir), + version, + PM_VERSION_2); + } else { + cond_stage_model = std::make_shared(clip_backend, + model_loader.tensor_storages_types, + SAFE_STR(sd_ctx_params->embedding_dir), + version); + } + diffusion_model = std::make_shared(backend, + model_loader.tensor_storages_types, + version, + sd_ctx_params->diffusion_flash_attn); } + cond_stage_model->alloc_params_buffer(); cond_stage_model->get_param_tensors(tensors); @@ -339,41 +383,55 @@ class StableDiffusionGGML { diffusion_model->get_param_tensors(tensors); if (!use_tiny_autoencoder) { - if (vae_on_cpu && !ggml_backend_is_cpu(backend)) { + if (sd_ctx_params->keep_vae_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_INFO("VAE Autoencoder: Using CPU backend"); vae_backend = ggml_backend_cpu_init(); } else { vae_backend = backend; } - first_stage_model = std::make_shared(vae_backend, vae_wtype, vae_decode_only, false, version); + first_stage_model = std::make_shared(vae_backend, + model_loader.tensor_storages_types, + "first_stage_model", + vae_decode_only, + false, + version); first_stage_model->alloc_params_buffer(); first_stage_model->get_param_tensors(tensors, "first_stage_model"); } else { - tae_first_stage = std::make_shared(backend, vae_wtype, vae_decode_only); + tae_first_stage = std::make_shared(backend, + model_loader.tensor_storages_types, + "decoder.layers", + vae_decode_only, + version); } // first_stage_model->get_param_tensors(tensors, "first_stage_model."); - if (control_net_path.size() > 0) { + if (strlen(SAFE_STR(sd_ctx_params->control_net_path)) > 0) { ggml_backend_t controlnet_backend = NULL; - if (control_net_cpu && !ggml_backend_is_cpu(backend)) { + if (sd_ctx_params->keep_control_net_on_cpu && !ggml_backend_is_cpu(backend)) { LOG_DEBUG("ControlNet: Using CPU backend"); controlnet_backend = ggml_backend_cpu_init(); } else { controlnet_backend = backend; } - control_net = std::make_shared(controlnet_backend, diffusion_model_wtype, version); + control_net = std::make_shared(controlnet_backend, model_loader.tensor_storages_types, version); } - pmid_model = std::make_shared(clip_backend, model_wtype, version); - if (id_embeddings_path.size() > 0) { - pmid_lora = std::make_shared(backend, model_wtype, id_embeddings_path, ""); + if (strstr(SAFE_STR(sd_ctx_params->stacked_id_embed_dir), "v2")) { + pmid_model = std::make_shared(backend, model_loader.tensor_storages_types, "pmid", version, PM_VERSION_2); + LOG_INFO("using PhotoMaker Version 2"); + } else { + pmid_model = std::make_shared(backend, model_loader.tensor_storages_types, "pmid", version); + } + if (strlen(SAFE_STR(sd_ctx_params->stacked_id_embed_dir)) > 0) { + pmid_lora = std::make_shared(backend, sd_ctx_params->stacked_id_embed_dir, ""); if (!pmid_lora->load_from_file(true)) { - LOG_WARN("load photomaker lora tensors from %s failed", id_embeddings_path.c_str()); + LOG_WARN("load photomaker lora tensors from %s failed", sd_ctx_params->stacked_id_embed_dir); return false; } - LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", id_embeddings_path.c_str()); - if (!model_loader.init_from_file(id_embeddings_path, "pmid.")) { - LOG_WARN("loading stacked ID embedding from '%s' failed", id_embeddings_path.c_str()); + LOG_INFO("loading stacked ID embedding (PHOTOMAKER) model file from '%s'", sd_ctx_params->stacked_id_embed_dir); + if (!model_loader.init_from_file(sd_ctx_params->stacked_id_embed_dir, "pmid.")) { + LOG_WARN("loading stacked ID embedding from '%s' failed", sd_ctx_params->stacked_id_embed_dir); } else { stacked_id = true; } @@ -383,14 +441,8 @@ class StableDiffusionGGML { LOG_ERROR(" pmid model params buffer allocation failed"); return false; } - // LOG_INFO("pmid param memory buffer size = %.2fMB ", - // pmid_model->params_buffer_size / 1024.0 / 1024.0); pmid_model->get_param_tensors(tensors, "pmid"); } - // if(stacked_id){ - // pmid_model.init_params(GGML_TYPE_F32); - // pmid_model.map_by_name(tensors, "pmid."); - // } } struct ggml_init_params params; @@ -451,7 +503,7 @@ class StableDiffusionGGML { } size_t control_net_params_mem_size = 0; if (control_net) { - if (!control_net->load_from_file(control_net_path)) { + if (!control_net->load_from_file(SAFE_STR(sd_ctx_params->control_net_path))) { return false; } control_net_params_mem_size = control_net->get_params_buffer_size(); @@ -507,12 +559,21 @@ class StableDiffusionGGML { } int64_t t1 = ggml_time_ms(); - LOG_INFO("loading model from '%s' completed, taking %.2fs", model_path.c_str(), (t1 - t0) * 1.0f / 1000); + LOG_INFO("loading model from '%s' completed, taking %.2fs", SAFE_STR(sd_ctx_params->model_path), (t1 - t0) * 1.0f / 1000); // check is_using_v_parameterization_for_sd2 - bool is_using_v_parameterization = false; - if (version == VERSION_SD2) { - if (is_using_v_parameterization_for_sd2(ctx)) { + + if (sd_version_is_sd2(version)) { + if (is_using_v_parameterization_for_sd2(ctx, sd_version_is_inpaint(version))) { + is_using_v_parameterization = true; + } + } else if (sd_version_is_sdxl(version)) { + if (model_loader.tensor_storages_types.find("edm_vpred.sigma_max") != model_loader.tensor_storages_types.end()) { + // CosXL models + // TODO: get sigma_min and sigma_max values from file + is_using_edm_v_parameterization = true; + } + if (model_loader.tensor_storages_types.find("v_pred") != model_loader.tensor_storages_types.end()) { is_using_v_parameterization = true; } } else if (version == VERSION_SVD) { @@ -520,25 +581,31 @@ class StableDiffusionGGML { is_using_v_parameterization = true; } - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (sd_version_is_sd3(version)) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(version)) { LOG_INFO("running in Flux FLOW mode"); - float shift = 1.15f; - if (version == VERSION_FLUX_SCHNELL) { - shift = 1.0f; // TODO: validate + float shift = 1.0f; // TODO: validate + for (auto pair : model_loader.tensor_storages_types) { + if (pair.first.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { + shift = 1.15f; + break; + } } denoiser = std::make_shared(shift); } else if (is_using_v_parameterization) { LOG_INFO("running in v-prediction mode"); denoiser = std::make_shared(); + } else if (is_using_edm_v_parameterization) { + LOG_INFO("running in v-prediction EDM mode"); + denoiser = std::make_shared(); } else { LOG_INFO("running in eps-prediction mode"); } - if (schedule != DEFAULT) { - switch (schedule) { + if (sd_ctx_params->schedule != DEFAULT) { + switch (sd_ctx_params->schedule) { case DISCRETE: LOG_INFO("running with discrete schedule"); denoiser->schedule = std::make_shared(); @@ -565,7 +632,7 @@ class StableDiffusionGGML { // Don't touch anything. break; default: - LOG_ERROR("Unknown schedule %i", schedule); + LOG_ERROR("Unknown schedule %i", sd_ctx_params->schedule); abort(); } } @@ -583,7 +650,7 @@ class StableDiffusionGGML { return true; } - bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx) { + bool is_using_v_parameterization_for_sd2(ggml_context* work_ctx, bool is_inpaint = false) { struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 4, 1); ggml_set_f32(x_t, 0.5); struct ggml_tensor* c = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 1024, 2, 1, 1); @@ -591,9 +658,15 @@ class StableDiffusionGGML { struct ggml_tensor* timesteps = ggml_new_tensor_1d(work_ctx, GGML_TYPE_F32, 1); ggml_set_f32(timesteps, 999); + + struct ggml_tensor* concat = is_inpaint ? ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, 8, 8, 5, 1) : NULL; + if (concat != NULL) { + ggml_set_f32(concat, 0); + } + int64_t t0 = ggml_time_ms(); struct ggml_tensor* out = ggml_dup_tensor(work_ctx, x_t); - diffusion_model->compute(n_threads, x_t, timesteps, c, NULL, NULL, NULL, -1, {}, 0.f, &out); + diffusion_model->compute(n_threads, x_t, timesteps, c, concat, NULL, NULL, {}, -1, {}, 0.f, &out); diffusion_model->free_compute_buffer(); double result = 0.f; @@ -626,14 +699,15 @@ class StableDiffusionGGML { LOG_WARN("can not find %s or %s for lora %s", st_file_path.c_str(), ckpt_file_path.c_str(), lora_name.c_str()); return; } - LoraModel lora(backend, model_wtype, file_path); + LoraModel lora(backend, file_path); if (!lora.load_from_file()) { LOG_WARN("load lora tensors from %s failed", file_path.c_str()); return; } lora.multiplier = multiplier; - lora.apply(tensors, n_threads); + // TODO: send version? + lora.apply(tensors, version, n_threads); lora.free_params_buffer(); int64_t t1 = ggml_time_ms(); @@ -649,19 +723,20 @@ class StableDiffusionGGML { for (auto& kv : lora_state) { const std::string& lora_name = kv.first; float multiplier = kv.second; - - if (curr_lora_state.find(lora_name) != curr_lora_state.end()) { - float curr_multiplier = curr_lora_state[lora_name]; - float multiplier_diff = multiplier - curr_multiplier; - if (multiplier_diff != 0.f) { - lora_state_diff[lora_name] = multiplier_diff; - } - } else { - lora_state_diff[lora_name] = multiplier; - } + lora_state_diff[lora_name] += multiplier; + } + for (auto& kv : curr_lora_state) { + const std::string& lora_name = kv.first; + float curr_multiplier = kv.second; + lora_state_diff[lora_name] -= curr_multiplier; } - LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size()); + size_t rm = lora_state_diff.size() - lora_state.size(); + if (rm != 0) { + LOG_INFO("Attempting to apply %lu LoRAs (removing %lu applied LoRAs)", lora_state.size(), rm); + } else { + LOG_INFO("Attempting to apply %lu LoRAs", lora_state.size()); + } for (auto& kv : lora_state_diff) { apply_lora(kv.first, kv.second); @@ -673,10 +748,10 @@ class StableDiffusionGGML { ggml_tensor* id_encoder(ggml_context* work_ctx, ggml_tensor* init_img, ggml_tensor* prompts_embeds, + ggml_tensor* id_embeds, std::vector& class_tokens_mask) { ggml_tensor* res = NULL; - pmid_model->compute(n_threads, init_img, prompts_embeds, class_tokens_mask, &res, work_ctx); - + pmid_model->compute(n_threads, init_img, prompts_embeds, id_embeds, class_tokens_mask, &res, work_ctx); return res; } @@ -763,15 +838,42 @@ class StableDiffusionGGML { ggml_tensor* noise, SDCondition cond, SDCondition uncond, + SDCondition img_cond, ggml_tensor* control_hint, float control_strength, - float min_cfg, - float cfg_scale, - float guidance, + sd_guidance_params_t guidance, + float eta, sample_method_t method, const std::vector& sigmas, int start_merge_step, - SDCondition id_cond) { + SDCondition id_cond, + std::vector ref_latents = {}, + ggml_tensor* denoise_mask = nullptr) { + std::vector skip_layers(guidance.slg.layers, guidance.slg.layers + guidance.slg.layer_count); + + float cfg_scale = guidance.txt_cfg; + float img_cfg_scale = guidance.img_cfg; + float slg_scale = guidance.slg.scale; + + float min_cfg = guidance.min_cfg; + + if (img_cfg_scale != cfg_scale && !sd_version_is_inpaint_or_unet_edit(version)) { + LOG_WARN("2-conditioning CFG is not supported with this model, disabling it for better performance..."); + img_cfg_scale = cfg_scale; + } + + LOG_DEBUG("Sample"); + struct ggml_init_params params; + size_t data_size = ggml_row_size(init_latent->type, init_latent->ne[0]); + for (int i = 1; i < 4; i++) { + data_size *= init_latent->ne[i]; + } + data_size += 1024; + params.mem_size = data_size * 3; + params.mem_buffer = NULL; + params.no_alloc = false; + ggml_context* tmp_ctx = ggml_init(params); + size_t steps = sigmas.size() - 1; // noise = load_tensor_from_file(work_ctx, "./rand0.bin"); // print_ggml_tensor(noise); @@ -781,14 +883,30 @@ class StableDiffusionGGML { struct ggml_tensor* noised_input = ggml_dup_tensor(work_ctx, noise); - bool has_unconditioned = cfg_scale != 1.0 && uncond.c_crossattn != NULL; + bool has_unconditioned = img_cfg_scale != 1.0 && uncond.c_crossattn != NULL; + bool has_img_cond = cfg_scale != img_cfg_scale && img_cond.c_crossattn != NULL; + bool has_skiplayer = slg_scale != 0.0 && skip_layers.size() > 0; // denoise wrapper - struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x); - struct ggml_tensor* out_uncond = NULL; + struct ggml_tensor* out_cond = ggml_dup_tensor(work_ctx, x); + struct ggml_tensor* out_uncond = NULL; + struct ggml_tensor* out_skip = NULL; + struct ggml_tensor* out_img_cond = NULL; + if (has_unconditioned) { out_uncond = ggml_dup_tensor(work_ctx, x); } + if (has_skiplayer) { + if (sd_version_is_dit(version)) { + out_skip = ggml_dup_tensor(work_ctx, x); + } else { + has_skiplayer = false; + LOG_WARN("SLG is incompatible with %s models", model_version_to_str[version]); + } + } + if (has_img_cond) { + out_img_cond = ggml_dup_tensor(work_ctx, x); + } struct ggml_tensor* denoised = ggml_dup_tensor(work_ctx, x); auto denoise = [&](ggml_tensor* input, float sigma, int step) -> ggml_tensor* { @@ -806,7 +924,7 @@ class StableDiffusionGGML { float t = denoiser->sigma_to_t(sigma); std::vector timesteps_vec(x->ne[3], t); // [N, ] auto timesteps = vector_to_ggml_tensor(work_ctx, timesteps_vec); - std::vector guidance_vec(x->ne[3], guidance); + std::vector guidance_vec(x->ne[3], guidance.distilled_guidance); auto guidance_tensor = vector_to_ggml_tensor(work_ctx, guidance_vec); copy_ggml_tensor(noised_input, input); @@ -831,6 +949,7 @@ class StableDiffusionGGML { cond.c_concat, cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, @@ -843,6 +962,7 @@ class StableDiffusionGGML { cond.c_concat, id_cond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, @@ -863,12 +983,53 @@ class StableDiffusionGGML { uncond.c_concat, uncond.c_vector, guidance_tensor, + ref_latents, -1, controls, control_strength, &out_uncond); negative_data = (float*)out_uncond->data; } + + float* img_cond_data = NULL; + if (has_img_cond) { + diffusion_model->compute(n_threads, + noised_input, + timesteps, + img_cond.c_crossattn, + img_cond.c_concat, + img_cond.c_vector, + guidance_tensor, + ref_latents, + -1, + controls, + control_strength, + &out_img_cond); + img_cond_data = (float*)out_img_cond->data; + } + + int step_count = sigmas.size(); + bool is_skiplayer_step = has_skiplayer && step > (int)(guidance.slg.layer_start * step_count) && step < (int)(guidance.slg.layer_end * step_count); + float* skip_layer_data = NULL; + if (is_skiplayer_step) { + LOG_DEBUG("Skipping layers at step %d\n", step); + // skip layer (same as conditionned) + diffusion_model->compute(n_threads, + noised_input, + timesteps, + cond.c_crossattn, + cond.c_concat, + cond.c_vector, + guidance_tensor, + ref_latents, + -1, + controls, + control_strength, + &out_skip, + NULL, + skip_layers); + skip_layer_data = (float*)out_skip->data; + } float* vec_denoised = (float*)denoised->data; float* vec_input = (float*)input->data; float* positive_data = (float*)out_cond->data; @@ -882,8 +1043,20 @@ class StableDiffusionGGML { int64_t i3 = i / out_cond->ne[0] * out_cond->ne[1] * out_cond->ne[2]; float scale = min_cfg + (cfg_scale - min_cfg) * (i3 * 1.0f / ne3); } else { - latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); + if (has_img_cond) { + // out_uncond + text_cfg_scale * (out_cond - out_img_cond) + image_cfg_scale * (out_img_cond - out_uncond) + latent_result = negative_data[i] + img_cfg_scale * (img_cond_data[i] - negative_data[i]) + cfg_scale * (positive_data[i] - img_cond_data[i]); + } else { + // img_cfg_scale == cfg_scale + latent_result = negative_data[i] + cfg_scale * (positive_data[i] - negative_data[i]); + } } + } else if (has_img_cond) { + // img_cfg_scale == 1 + latent_result = img_cond_data[i] + cfg_scale * (positive_data[i] - img_cond_data[i]); + } + if (is_skiplayer_step) { + latent_result = latent_result + (positive_data[i] - skip_layer_data[i]) * slg_scale; } // v = latent_result, eps = latent_result // denoised = (v * c_out + input * c_skip) or (input + eps * c_out) @@ -894,10 +1067,23 @@ class StableDiffusionGGML { pretty_progress(step, (int)steps, (t1 - t0) / 1000000.f); // LOG_INFO("step %d sampling completed taking %.2fs", step, (t1 - t0) * 1.0f / 1000000); } + if (denoise_mask != nullptr) { + for (int64_t x = 0; x < denoised->ne[0]; x++) { + for (int64_t y = 0; y < denoised->ne[1]; y++) { + float mask = ggml_tensor_get_f32(denoise_mask, x, y); + for (int64_t k = 0; k < denoised->ne[2]; k++) { + float init = ggml_tensor_get_f32(init_latent, x, y, k); + float den = ggml_tensor_get_f32(denoised, x, y, k); + ggml_tensor_set_f32(denoised, init + mask * (den - init), x, y, k); + } + } + } + } + return denoised; }; - sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng); + sample_k_diffusion(method, denoise, work_ctx, x, sigmas, rng, eta); x = denoiser->inverse_noise_scaling(sigmas[sigmas.size() - 1], x); @@ -948,9 +1134,9 @@ class StableDiffusionGGML { if (use_tiny_autoencoder) { C = 4; } else { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { + if (sd_version_is_sd3(version)) { C = 32; - } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(version)) { C = 32; } } @@ -1011,72 +1197,301 @@ class StableDiffusionGGML { /*================================================= SD API ==================================================*/ +#define NONE_STR "NONE" + +const char* sd_type_name(enum sd_type_t type) { + return ggml_type_name((ggml_type)type); +} + +enum sd_type_t str_to_sd_type(const char* str) { + for (int i = 0; i < SD_TYPE_COUNT; i++) { + auto trait = ggml_get_type_traits((ggml_type)i); + if (!strcmp(str, trait->type_name)) { + return (enum sd_type_t)i; + } + } + return SD_TYPE_COUNT; +} + +const char* rng_type_to_str[] = { + "std_default", + "cuda", +}; + +const char* sd_rng_type_name(enum rng_type_t rng_type) { + if (rng_type < RNG_TYPE_COUNT) { + return rng_type_to_str[rng_type]; + } + return NONE_STR; +} + +enum rng_type_t str_to_rng_type(const char* str) { + for (int i = 0; i < RNG_TYPE_COUNT; i++) { + if (!strcmp(str, rng_type_to_str[i])) { + return (enum rng_type_t)i; + } + } + return RNG_TYPE_COUNT; +} + +const char* sample_method_to_str[] = { + "euler_a", + "euler", + "heun", + "dpm2", + "dpm++2s_a", + "dpm++2m", + "dpm++2mv2", + "ipndm", + "ipndm_v", + "lcm", + "ddim_trailing", + "tcd", +}; + +const char* sd_sample_method_name(enum sample_method_t sample_method) { + if (sample_method < SAMPLE_METHOD_COUNT) { + return sample_method_to_str[sample_method]; + } + return NONE_STR; +} + +enum sample_method_t str_to_sample_method(const char* str) { + for (int i = 0; i < SAMPLE_METHOD_COUNT; i++) { + if (!strcmp(str, sample_method_to_str[i])) { + return (enum sample_method_t)i; + } + } + return SAMPLE_METHOD_COUNT; +} + +const char* schedule_to_str[] = { + "default", + "discrete", + "karras", + "exponential", + "ays", + "gits", +}; + +const char* sd_schedule_name(enum schedule_t schedule) { + if (schedule < SCHEDULE_COUNT) { + return schedule_to_str[schedule]; + } + return NONE_STR; +} + +enum schedule_t str_to_schedule(const char* str) { + for (int i = 0; i < SCHEDULE_COUNT; i++) { + if (!strcmp(str, schedule_to_str[i])) { + return (enum schedule_t)i; + } + } + return SCHEDULE_COUNT; +} + +void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params) { + memset((void*)sd_ctx_params, 0, sizeof(sd_ctx_params_t)); + sd_ctx_params->vae_decode_only = true; + sd_ctx_params->vae_tiling = false; + sd_ctx_params->free_params_immediately = true; + sd_ctx_params->n_threads = get_num_physical_cores(); + sd_ctx_params->wtype = SD_TYPE_COUNT; + sd_ctx_params->rng_type = CUDA_RNG; + sd_ctx_params->schedule = DEFAULT; + sd_ctx_params->keep_clip_on_cpu = false; + sd_ctx_params->keep_control_net_on_cpu = false; + sd_ctx_params->keep_vae_on_cpu = false; + sd_ctx_params->diffusion_flash_attn = false; + sd_ctx_params->chroma_use_dit_mask = true; + sd_ctx_params->chroma_use_t5_mask = false; + sd_ctx_params->chroma_t5_mask_pad = 1; +} + +char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params) { + char* buf = (char*)malloc(4096); + if (!buf) + return NULL; + buf[0] = '\0'; + + snprintf(buf + strlen(buf), 4096 - strlen(buf), + "model_path: %s\n" + "clip_l_path: %s\n" + "clip_g_path: %s\n" + "t5xxl_path: %s\n" + "diffusion_model_path: %s\n" + "vae_path: %s\n" + "taesd_path: %s\n" + "control_net_path: %s\n" + "lora_model_dir: %s\n" + "embedding_dir: %s\n" + "stacked_id_embed_dir: %s\n" + "vae_decode_only: %s\n" + "vae_tiling: %s\n" + "free_params_immediately: %s\n" + "n_threads: %d\n" + "wtype: %s\n" + "rng_type: %s\n" + "schedule: %s\n" + "keep_clip_on_cpu: %s\n" + "keep_control_net_on_cpu: %s\n" + "keep_vae_on_cpu: %s\n" + "diffusion_flash_attn: %s\n" + "chroma_use_dit_mask: %s\n" + "chroma_use_t5_mask: %s\n" + "chroma_t5_mask_pad: %d\n", + SAFE_STR(sd_ctx_params->model_path), + SAFE_STR(sd_ctx_params->clip_l_path), + SAFE_STR(sd_ctx_params->clip_g_path), + SAFE_STR(sd_ctx_params->t5xxl_path), + SAFE_STR(sd_ctx_params->diffusion_model_path), + SAFE_STR(sd_ctx_params->vae_path), + SAFE_STR(sd_ctx_params->taesd_path), + SAFE_STR(sd_ctx_params->control_net_path), + SAFE_STR(sd_ctx_params->lora_model_dir), + SAFE_STR(sd_ctx_params->embedding_dir), + SAFE_STR(sd_ctx_params->stacked_id_embed_dir), + BOOL_STR(sd_ctx_params->vae_decode_only), + BOOL_STR(sd_ctx_params->vae_tiling), + BOOL_STR(sd_ctx_params->free_params_immediately), + sd_ctx_params->n_threads, + sd_type_name(sd_ctx_params->wtype), + sd_rng_type_name(sd_ctx_params->rng_type), + sd_schedule_name(sd_ctx_params->schedule), + BOOL_STR(sd_ctx_params->keep_clip_on_cpu), + BOOL_STR(sd_ctx_params->keep_control_net_on_cpu), + BOOL_STR(sd_ctx_params->keep_vae_on_cpu), + BOOL_STR(sd_ctx_params->diffusion_flash_attn), + BOOL_STR(sd_ctx_params->chroma_use_dit_mask), + BOOL_STR(sd_ctx_params->chroma_use_t5_mask), + sd_ctx_params->chroma_t5_mask_pad); + + return buf; +} + +void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params) { + memset((void*)sd_img_gen_params, 0, sizeof(sd_img_gen_params_t)); + sd_img_gen_params->clip_skip = -1; + sd_img_gen_params->guidance.txt_cfg = 7.0f; + sd_img_gen_params->guidance.min_cfg = 1.0f; + sd_img_gen_params->guidance.img_cfg = INFINITY; + sd_img_gen_params->guidance.distilled_guidance = 3.5f; + sd_img_gen_params->guidance.slg.layer_count = 0; + sd_img_gen_params->guidance.slg.layer_start = 0.01f; + sd_img_gen_params->guidance.slg.layer_end = 0.2f; + sd_img_gen_params->guidance.slg.scale = 0.f; + sd_img_gen_params->ref_images_count = 0; + sd_img_gen_params->width = 512; + sd_img_gen_params->height = 512; + sd_img_gen_params->sample_method = EULER_A; + sd_img_gen_params->sample_steps = 20; + sd_img_gen_params->eta = 0.f; + sd_img_gen_params->strength = 0.75f; + sd_img_gen_params->seed = -1; + sd_img_gen_params->batch_count = 1; + sd_img_gen_params->control_strength = 0.9f; + sd_img_gen_params->style_strength = 20.f; + sd_img_gen_params->normalize_input = false; +} + +char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params) { + char* buf = (char*)malloc(4096); + if (!buf) + return NULL; + buf[0] = '\0'; + + snprintf(buf + strlen(buf), 4096 - strlen(buf), + "prompt: %s\n" + "negative_prompt: %s\n" + "clip_skip: %d\n" + "txt_cfg: %.2f\n" + "img_cfg: %.2f\n" + "min_cfg: %.2f\n" + "distilled_guidance: %.2f\n" + "slg.layer_count: %zu\n" + "slg.layer_start: %.2f\n" + "slg.layer_end: %.2f\n" + "slg.scale: %.2f\n" + "width: %d\n" + "height: %d\n" + "sample_method: %s\n" + "sample_steps: %d\n" + "eta: %.2f\n" + "strength: %.2f\n" + "seed: %" PRId64 + "\n" + "batch_count: %d\n" + "ref_images_count: %d\n" + "control_strength: %.2f\n" + "style_strength: %.2f\n" + "normalize_input: %s\n" + "input_id_images_path: %s\n", + SAFE_STR(sd_img_gen_params->prompt), + SAFE_STR(sd_img_gen_params->negative_prompt), + sd_img_gen_params->clip_skip, + sd_img_gen_params->guidance.txt_cfg, + sd_img_gen_params->guidance.img_cfg, + sd_img_gen_params->guidance.min_cfg, + sd_img_gen_params->guidance.distilled_guidance, + sd_img_gen_params->guidance.slg.layer_count, + sd_img_gen_params->guidance.slg.layer_start, + sd_img_gen_params->guidance.slg.layer_end, + sd_img_gen_params->guidance.slg.scale, + sd_img_gen_params->width, + sd_img_gen_params->height, + sd_sample_method_name(sd_img_gen_params->sample_method), + sd_img_gen_params->sample_steps, + sd_img_gen_params->eta, + sd_img_gen_params->strength, + sd_img_gen_params->seed, + sd_img_gen_params->batch_count, + sd_img_gen_params->ref_images_count, + sd_img_gen_params->control_strength, + sd_img_gen_params->style_strength, + BOOL_STR(sd_img_gen_params->normalize_input), + SAFE_STR(sd_img_gen_params->input_id_images_path)); + + return buf; +} + +void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params) { + memset((void*)sd_vid_gen_params, 0, sizeof(sd_vid_gen_params_t)); + sd_vid_gen_params->guidance.txt_cfg = 7.0f; + sd_vid_gen_params->guidance.min_cfg = 1.0f; + sd_vid_gen_params->guidance.img_cfg = INFINITY; + sd_vid_gen_params->guidance.distilled_guidance = 3.5f; + sd_vid_gen_params->guidance.slg.layer_count = 0; + sd_vid_gen_params->guidance.slg.layer_start = 0.01f; + sd_vid_gen_params->guidance.slg.layer_end = 0.2f; + sd_vid_gen_params->guidance.slg.scale = 0.f; + sd_vid_gen_params->width = 512; + sd_vid_gen_params->height = 512; + sd_vid_gen_params->sample_method = EULER_A; + sd_vid_gen_params->sample_steps = 20; + sd_vid_gen_params->strength = 0.75f; + sd_vid_gen_params->seed = -1; + sd_vid_gen_params->video_frames = 6; + sd_vid_gen_params->motion_bucket_id = 127; + sd_vid_gen_params->fps = 6; + sd_vid_gen_params->augmentation_level = 0.f; +} + struct sd_ctx_t { StableDiffusionGGML* sd = NULL; }; -sd_ctx_t* new_sd_ctx(const char* model_path_c_str, - const char* clip_l_path_c_str, - const char* clip_g_path_c_str, - const char* t5xxl_path_c_str, - const char* diffusion_model_path_c_str, - const char* vae_path_c_str, - const char* taesd_path_c_str, - const char* control_net_path_c_str, - const char* lora_model_dir_c_str, - const char* embed_dir_c_str, - const char* id_embed_dir_c_str, - bool vae_decode_only, - bool vae_tiling, - bool free_params_immediately, - int n_threads, - enum sd_type_t wtype, - enum rng_type_t rng_type, - enum schedule_t s, - bool keep_clip_on_cpu, - bool keep_control_net_cpu, - bool keep_vae_on_cpu) { +sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params) { sd_ctx_t* sd_ctx = (sd_ctx_t*)malloc(sizeof(sd_ctx_t)); if (sd_ctx == NULL) { return NULL; } - std::string model_path(model_path_c_str); - std::string clip_l_path(clip_l_path_c_str); - std::string clip_g_path(clip_g_path_c_str); - std::string t5xxl_path(t5xxl_path_c_str); - std::string diffusion_model_path(diffusion_model_path_c_str); - std::string vae_path(vae_path_c_str); - std::string taesd_path(taesd_path_c_str); - std::string control_net_path(control_net_path_c_str); - std::string embd_path(embed_dir_c_str); - std::string id_embd_path(id_embed_dir_c_str); - std::string lora_model_dir(lora_model_dir_c_str); - - sd_ctx->sd = new StableDiffusionGGML(n_threads, - vae_decode_only, - free_params_immediately, - lora_model_dir, - rng_type); + + sd_ctx->sd = new StableDiffusionGGML(); if (sd_ctx->sd == NULL) { return NULL; } - if (!sd_ctx->sd->load_from_file(model_path, - clip_l_path, - clip_g_path, - t5xxl_path_c_str, - diffusion_model_path, - vae_path, - control_net_path, - embd_path, - id_embd_path, - taesd_path, - vae_tiling, - (ggml_type)wtype, - s, - keep_clip_on_cpu, - keep_control_net_cpu, - keep_vae_on_cpu)) { + if (!sd_ctx->sd->init(sd_ctx_params)) { delete sd_ctx->sd; sd_ctx->sd = NULL; free(sd_ctx); @@ -1093,25 +1508,28 @@ void free_sd_ctx(sd_ctx_t* sd_ctx) { free(sd_ctx); } -sd_image_t* generate_image(sd_ctx_t* sd_ctx, - struct ggml_context* work_ctx, - ggml_tensor* init_latent, - std::string prompt, - std::string negative_prompt, - int clip_skip, - float cfg_scale, - float guidance, - int width, - int height, - enum sample_method_t sample_method, - const std::vector& sigmas, - int64_t seed, - int batch_count, - const sd_image_t* control_cond, - float control_strength, - float style_ratio, - bool normalize_input, - std::string input_id_images_path) { +sd_image_t* generate_image_internal(sd_ctx_t* sd_ctx, + struct ggml_context* work_ctx, + ggml_tensor* init_latent, + std::string prompt, + std::string negative_prompt, + int clip_skip, + sd_guidance_params_t guidance, + float eta, + int width, + int height, + enum sample_method_t sample_method, + const std::vector& sigmas, + int64_t seed, + int batch_count, + const sd_image_t* control_cond, + float control_strength, + float style_ratio, + bool normalize_input, + std::string input_id_images_path, + std::vector ref_latents, + ggml_tensor* concat_latent = NULL, + ggml_tensor* denoise_mask = NULL) { if (seed < 0) { // Generally, when using the provided command line, the seed is always >0. // However, to prevent potential issues if 'stable-diffusion.cpp' is invoked as a library @@ -1151,7 +1569,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, if (sd_ctx->sd->stacked_id) { if (!sd_ctx->sd->pmid_lora->applied) { t0 = ggml_time_ms(); - sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->n_threads); + sd_ctx->sd->pmid_lora->apply(sd_ctx->sd->tensors, sd_ctx->sd->version, sd_ctx->sd->n_threads); t1 = ggml_time_ms(); sd_ctx->sd->pmid_lora->applied = true; LOG_INFO("pmid_lora apply completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); @@ -1161,11 +1579,15 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, } // preprocess input id images std::vector input_id_images; + bool pmv2 = sd_ctx->sd->pmid_model->get_version() == PM_VERSION_2; if (sd_ctx->sd->pmid_model && input_id_images_path.size() > 0) { std::vector img_files = get_files_from_dir(input_id_images_path); for (std::string img_file : img_files) { int c = 0; int width, height; + if (ends_with(img_file, "safetensors")) { + continue; + } uint8_t* input_image_buffer = stbi_load(img_file.c_str(), &width, &height, &c, 3); if (input_image_buffer == NULL) { LOG_ERROR("PhotoMaker load image from '%s' failed", img_file.c_str()); @@ -1203,18 +1625,23 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, else sd_mul_images_to_tensor(init_image->data, init_img, i, NULL, NULL); } - t0 = ggml_time_ms(); - auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, - sd_ctx->sd->n_threads, prompt, - clip_skip, - width, - height, - num_input_images, - sd_ctx->sd->diffusion_model->get_adm_in_channels()); - id_cond = std::get<0>(cond_tup); - class_tokens_mask = std::get<1>(cond_tup); // - - id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, class_tokens_mask); + t0 = ggml_time_ms(); + auto cond_tup = sd_ctx->sd->cond_stage_model->get_learned_condition_with_trigger(work_ctx, + sd_ctx->sd->n_threads, prompt, + clip_skip, + width, + height, + num_input_images, + sd_ctx->sd->diffusion_model->get_adm_in_channels()); + id_cond = std::get<0>(cond_tup); + class_tokens_mask = std::get<1>(cond_tup); // + struct ggml_tensor* id_embeds = NULL; + if (pmv2) { + // id_embeds = sd_ctx->sd->pmid_id_embeds->get(); + id_embeds = load_tensor_from_file(work_ctx, path_join(input_id_images_path, "id_embeds.bin")); + // print_ggml_tensor(id_embeds, true, "id_embeds:"); + } + id_cond.c_crossattn = sd_ctx->sd->id_encoder(work_ctx, init_img, id_cond.c_crossattn, id_embeds, class_tokens_mask); t1 = ggml_time_ms(); LOG_INFO("Photomaker ID Stacking, taking %" PRId64 " ms", t1 - t0); if (sd_ctx->sd->free_params_immediately) { @@ -1250,9 +1677,10 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, sd_ctx->sd->diffusion_model->get_adm_in_channels()); SDCondition uncond; - if (cfg_scale != 1.0) { + if (guidance.txt_cfg != 1.0 || + (sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) { bool force_zero_embeddings = false; - if (sd_ctx->sd->version == VERSION_SDXL && negative_prompt.size() == 0) { + if (sd_version_is_sdxl(sd_ctx->sd->version) && negative_prompt.size() == 0 && !sd_ctx->sd->is_using_edm_v_parameterization) { force_zero_embeddings = true; } uncond = sd_ctx->sd->cond_stage_model->get_learned_condition(work_ctx, @@ -1281,14 +1709,59 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { C = 16; } int W = width / 8; int H = height / 8; LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + int64_t mask_channels = 1; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + mask_channels = 8 * 8; // flatten the whole mask + } + auto empty_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, init_latent->ne[0], init_latent->ne[1], mask_channels + init_latent->ne[2], 1); + // no mask, set the whole image as masked + for (int64_t x = 0; x < empty_latent->ne[0]; x++) { + for (int64_t y = 0; y < empty_latent->ne[1]; y++) { + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + // TODO: this might be wrong + for (int64_t c = 0; c < init_latent->ne[2]; c++) { + ggml_tensor_set_f32(empty_latent, 0, x, y, c); + } + for (int64_t c = init_latent->ne[2]; c < empty_latent->ne[2]; c++) { + ggml_tensor_set_f32(empty_latent, 1, x, y, c); + } + } else { + ggml_tensor_set_f32(empty_latent, 1, x, y, 0); + for (int64_t c = 1; c < empty_latent->ne[2]; c++) { + ggml_tensor_set_f32(empty_latent, 0, x, y, c); + } + } + } + } + if (concat_latent == NULL) { + concat_latent = empty_latent; + } + cond.c_concat = concat_latent; + uncond.c_concat = empty_latent; + denoise_mask = NULL; + } else if (sd_version_is_unet_edit(sd_ctx->sd->version)) { + auto empty_latent = ggml_dup_tensor(work_ctx, init_latent); + ggml_set_f32(empty_latent, 0); + uncond.c_concat = empty_latent; + if (concat_latent == NULL) { + concat_latent = empty_latent; + } + cond.c_concat = ref_latents[0]; + } + SDCondition img_cond; + if (uncond.c_crossattn != NULL && + (sd_version_is_inpaint_or_unet_edit(sd_ctx->sd->version) && guidance.txt_cfg != guidance.img_cfg)) { + img_cond = SDCondition(uncond.c_crossattn, uncond.c_vector, cond.c_concat); + } for (int b = 0; b < batch_count; b++) { int64_t sampling_start = ggml_time_ms(); int64_t cur_seed = seed + b; @@ -1307,20 +1780,26 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, LOG_INFO("PHOTOMAKER: start_merge_step: %d", start_merge_step); } + // Disable min_cfg + guidance.min_cfg = guidance.txt_cfg; + struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, noise, cond, uncond, + img_cond, image_hint, control_strength, - cfg_scale, - cfg_scale, guidance, + eta, sample_method, sigmas, start_merge_step, - id_cond); + id_cond, + ref_latents, + denoise_mask); + // struct ggml_tensor* x_0 = load_tensor_from_file(ctx, "samples_ddim.bin"); // print_ggml_tensor(x_0); int64_t sampling_end = ggml_time_ms(); @@ -1370,136 +1849,51 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, return result_images; } -sd_image_t* txt2img(sd_ctx_t* sd_ctx, - const char* prompt_c_str, - const char* negative_prompt_c_str, - int clip_skip, - float cfg_scale, - float guidance, - int width, - int height, - enum sample_method_t sample_method, - int sample_steps, - int64_t seed, - int batch_count, - const sd_image_t* control_cond, - float control_strength, - float style_ratio, - bool normalize_input, - const char* input_id_images_path_c_str) { - LOG_DEBUG("txt2img %dx%d", width, height); - if (sd_ctx == NULL) { - return NULL; - } - - struct ggml_init_params params; - params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { - params.mem_size *= 3; - } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { - params.mem_size *= 4; - } - if (sd_ctx->sd->stacked_id) { - params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB - } - params.mem_size += width * height * 3 * sizeof(float); - params.mem_size *= batch_count; - params.mem_buffer = NULL; - params.no_alloc = false; - // LOG_DEBUG("mem_size %u ", params.mem_size); - - struct ggml_context* work_ctx = ggml_init(params); - if (!work_ctx) { - LOG_ERROR("ggml_init() failed"); - return NULL; - } - - size_t t0 = ggml_time_ms(); - - std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); - +ggml_tensor* generate_init_latent(sd_ctx_t* sd_ctx, + ggml_context* work_ctx, + int width, + int height) { int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { C = 16; - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { C = 16; } int W = width / 8; int H = height / 8; ggml_tensor* init_latent = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { ggml_set_f32(init_latent, 0.0609f); - } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + } else if (sd_version_is_flux(sd_ctx->sd->version)) { ggml_set_f32(init_latent, 0.1159f); } else { ggml_set_f32(init_latent, 0.f); } - - sd_image_t* result_images = generate_image(sd_ctx, - work_ctx, - init_latent, - prompt_c_str, - negative_prompt_c_str, - clip_skip, - cfg_scale, - guidance, - width, - height, - sample_method, - sigmas, - seed, - batch_count, - control_cond, - control_strength, - style_ratio, - normalize_input, - input_id_images_path_c_str); - - size_t t1 = ggml_time_ms(); - - LOG_INFO("txt2img completed in %.2fs", (t1 - t0) * 1.0f / 1000); - - return result_images; + return init_latent; } -sd_image_t* img2img(sd_ctx_t* sd_ctx, - sd_image_t init_image, - const char* prompt_c_str, - const char* negative_prompt_c_str, - int clip_skip, - float cfg_scale, - float guidance, - int width, - int height, - sample_method_t sample_method, - int sample_steps, - float strength, - int64_t seed, - int batch_count, - const sd_image_t* control_cond, - float control_strength, - float style_ratio, - bool normalize_input, - const char* input_id_images_path_c_str) { - LOG_DEBUG("img2img %dx%d", width, height); - if (sd_ctx == NULL) { +sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params) { + int width = sd_img_gen_params->width; + int height = sd_img_gen_params->height; + LOG_DEBUG("generate_image %dx%d", width, height); + if (sd_ctx == NULL || sd_img_gen_params == NULL) { return NULL; } struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { - params.mem_size *= 2; - } - if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_sd3(sd_ctx->sd->version)) { params.mem_size *= 3; } + if (sd_version_is_flux(sd_ctx->sd->version)) { + params.mem_size *= 4; + } if (sd_ctx->sd->stacked_id) { params.mem_size += static_cast(10 * 1024 * 1024); // 10 MB } - params.mem_size += width * height * 3 * sizeof(float) * 2; - params.mem_size *= batch_count; + params.mem_size += width * height * 3 * sizeof(float) * 3; + params.mem_size += width * height * 3 * sizeof(float) * 3 * sd_img_gen_params->ref_images_count; + params.mem_size *= sd_img_gen_params->batch_count; params.mem_buffer = NULL; params.no_alloc = false; // LOG_DEBUG("mem_size %u ", params.mem_size); @@ -1510,85 +1904,197 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, return NULL; } - size_t t0 = ggml_time_ms(); - + int64_t seed = sd_img_gen_params->seed; if (seed < 0) { srand((int)time(NULL)); seed = rand(); } sd_ctx->sd->rng->manual_seed(seed); - ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); - sd_image_to_tensor(init_image.data, init_img); - ggml_tensor* init_latent = NULL; - if (!sd_ctx->sd->use_tiny_autoencoder) { - ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img); - init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + size_t t0 = ggml_time_ms(); + + ggml_tensor* init_latent = NULL; + ggml_tensor* concat_latent = NULL; + ggml_tensor* denoise_mask = NULL; + std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_img_gen_params->sample_steps); + + if (sd_img_gen_params->init_image.data) { + LOG_INFO("IMG2IMG"); + + size_t t_enc = static_cast(sd_img_gen_params->sample_steps * sd_img_gen_params->strength); + if (t_enc == sd_img_gen_params->sample_steps) + t_enc--; + LOG_INFO("target t_enc is %zu steps", t_enc); + std::vector sigma_sched; + sigma_sched.assign(sigmas.begin() + sd_img_gen_params->sample_steps - t_enc - 1, sigmas.end()); + sigmas = sigma_sched; + + ggml_tensor* init_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + ggml_tensor* mask_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 1, 1); + + sd_mask_to_tensor(sd_img_gen_params->mask_image.data, mask_img); + sd_image_to_tensor(sd_img_gen_params->init_image.data, init_img); + + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + int64_t mask_channels = 1; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + mask_channels = 8 * 8; // flatten the whole mask + } + ggml_tensor* masked_img = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width, height, 3, 1); + sd_apply_mask(init_img, mask_img, masked_img); + ggml_tensor* masked_latent = NULL; + if (!sd_ctx->sd->use_tiny_autoencoder) { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + masked_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } else { + masked_latent = sd_ctx->sd->encode_first_stage(work_ctx, masked_img); + } + concat_latent = ggml_new_tensor_4d(work_ctx, + GGML_TYPE_F32, + masked_latent->ne[0], + masked_latent->ne[1], + mask_channels + masked_latent->ne[2], + 1); + for (int ix = 0; ix < masked_latent->ne[0]; ix++) { + for (int iy = 0; iy < masked_latent->ne[1]; iy++) { + int mx = ix * 8; + int my = iy * 8; + if (sd_ctx->sd->version == VERSION_FLUX_FILL) { + for (int k = 0; k < masked_latent->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_latent, ix, iy, k); + ggml_tensor_set_f32(concat_latent, v, ix, iy, k); + } + // "Encode" 8x8 mask chunks into a flattened 1x64 vector, and concatenate to masked image + for (int x = 0; x < 8; x++) { + for (int y = 0; y < 8; y++) { + float m = ggml_tensor_get_f32(mask_img, mx + x, my + y); + // TODO: check if the way the mask is flattened is correct (is it supposed to be x*8+y or x+8*y?) + // python code was using "b (h 8) (w 8) -> b (8 8) h w" + ggml_tensor_set_f32(concat_latent, m, ix, iy, masked_latent->ne[2] + x * 8 + y); + } + } + } else { + float m = ggml_tensor_get_f32(mask_img, mx, my); + ggml_tensor_set_f32(concat_latent, m, ix, iy, 0); + for (int k = 0; k < masked_latent->ne[2]; k++) { + float v = ggml_tensor_get_f32(masked_latent, ix, iy, k); + ggml_tensor_set_f32(concat_latent, v, ix, iy, k + mask_channels); + } + } + } + } + } + + { + // LOG_WARN("Inpainting with a base model is not great"); + denoise_mask = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, width / 8, height / 8, 1, 1); + for (int ix = 0; ix < denoise_mask->ne[0]; ix++) { + for (int iy = 0; iy < denoise_mask->ne[1]; iy++) { + int mx = ix * 8; + int my = iy * 8; + float m = ggml_tensor_get_f32(mask_img, mx, my); + ggml_tensor_set_f32(denoise_mask, m, ix, iy); + } + } + } + + if (!sd_ctx->sd->use_tiny_autoencoder) { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, init_img); + init_latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } else { + init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); + } } else { - init_latent = sd_ctx->sd->encode_first_stage(work_ctx, init_img); + LOG_INFO("TXT2IMG"); + if (sd_version_is_inpaint(sd_ctx->sd->version)) { + LOG_WARN("This is an inpainting model, this should only be used in img2img mode with a mask"); + } + init_latent = generate_init_latent(sd_ctx, work_ctx, width, height); + } + + if (sd_img_gen_params->ref_images_count > 0) { + LOG_INFO("EDIT mode"); + } + + std::vector ref_latents; + for (int i = 0; i < sd_img_gen_params->ref_images_count; i++) { + ggml_tensor* img = ggml_new_tensor_4d(work_ctx, + GGML_TYPE_F32, + sd_img_gen_params->ref_images[i].width, + sd_img_gen_params->ref_images[i].height, + 3, + 1); + sd_image_to_tensor(sd_img_gen_params->ref_images[i].data, img); + + ggml_tensor* latent = NULL; + if (sd_ctx->sd->use_tiny_autoencoder) { + latent = sd_ctx->sd->encode_first_stage(work_ctx, img); + } else if (sd_ctx->sd->version == VERSION_SD1_PIX2PIX) { + latent = sd_ctx->sd->encode_first_stage(work_ctx, img); + latent = ggml_view_3d(work_ctx, + latent, + latent->ne[0], + latent->ne[1], + latent->ne[2] / 2, + latent->nb[1], + latent->nb[2], + 0); + } else { + ggml_tensor* moments = sd_ctx->sd->encode_first_stage(work_ctx, img); + latent = sd_ctx->sd->get_first_stage_encoding(work_ctx, moments); + } + ref_latents.push_back(latent); + } + + if (sd_img_gen_params->init_image.data != NULL || sd_img_gen_params->ref_images_count > 0) { + size_t t1 = ggml_time_ms(); + LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); } - print_ggml_tensor(init_latent, true); - size_t t1 = ggml_time_ms(); - LOG_INFO("encode_first_stage completed, taking %.2fs", (t1 - t0) * 1.0f / 1000); - - std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); - size_t t_enc = static_cast(sample_steps * strength); - LOG_INFO("target t_enc is %zu steps", t_enc); - std::vector sigma_sched; - sigma_sched.assign(sigmas.begin() + sample_steps - t_enc - 1, sigmas.end()); - - sd_image_t* result_images = generate_image(sd_ctx, - work_ctx, - init_latent, - prompt_c_str, - negative_prompt_c_str, - clip_skip, - cfg_scale, - guidance, - width, - height, - sample_method, - sigma_sched, - seed, - batch_count, - control_cond, - control_strength, - style_ratio, - normalize_input, - input_id_images_path_c_str); + + sd_image_t* result_images = generate_image_internal(sd_ctx, + work_ctx, + init_latent, + SAFE_STR(sd_img_gen_params->prompt), + SAFE_STR(sd_img_gen_params->negative_prompt), + sd_img_gen_params->clip_skip, + sd_img_gen_params->guidance, + sd_img_gen_params->eta, + width, + height, + sd_img_gen_params->sample_method, + sigmas, + seed, + sd_img_gen_params->batch_count, + sd_img_gen_params->control_cond, + sd_img_gen_params->control_strength, + sd_img_gen_params->style_strength, + sd_img_gen_params->normalize_input, + sd_img_gen_params->input_id_images_path, + ref_latents, + concat_latent, + denoise_mask); size_t t2 = ggml_time_ms(); - LOG_INFO("img2img completed in %.2fs", (t1 - t0) * 1.0f / 1000); + LOG_INFO("generate_image completed in %.2fs", (t2 - t0) * 1.0f / 1000); return result_images; } -SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, - sd_image_t init_image, - int width, - int height, - int video_frames, - int motion_bucket_id, - int fps, - float augmentation_level, - float min_cfg, - float cfg_scale, - enum sample_method_t sample_method, - int sample_steps, - float strength, - int64_t seed) { - if (sd_ctx == NULL) { +SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params) { + if (sd_ctx == NULL || sd_vid_gen_params == NULL) { return NULL; } + int width = sd_vid_gen_params->width; + int height = sd_vid_gen_params->height; LOG_INFO("img2vid %dx%d", width, height); - std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); + std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sd_vid_gen_params->sample_steps); struct ggml_init_params params; params.mem_size = static_cast(10 * 1024) * 1024; // 10 MB - params.mem_size += width * height * 3 * sizeof(float) * video_frames; + params.mem_size += width * height * 3 * sizeof(float) * sd_vid_gen_params->video_frames; params.mem_buffer = NULL; params.no_alloc = false; // LOG_DEBUG("mem_size %u ", params.mem_size); @@ -1600,6 +2106,7 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, return NULL; } + int64_t seed = sd_vid_gen_params->seed; if (seed < 0) { seed = (int)time(NULL); } @@ -1609,12 +2116,12 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, int64_t t0 = ggml_time_ms(); SDCondition cond = sd_ctx->sd->get_svd_condition(work_ctx, - init_image, + sd_vid_gen_params->init_image, width, height, - fps, - motion_bucket_id, - augmentation_level); + sd_vid_gen_params->fps, + sd_vid_gen_params->motion_bucket_id, + sd_vid_gen_params->augmentation_level); auto uc_crossattn = ggml_dup_tensor(work_ctx, cond.c_crossattn); ggml_set_f32(uc_crossattn, 0.f); @@ -1636,24 +2143,24 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, int C = 4; int W = width / 8; int H = height / 8; - struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, video_frames); + struct ggml_tensor* x_t = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, sd_vid_gen_params->video_frames); ggml_set_f32(x_t, 0.f); - struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, video_frames); + struct ggml_tensor* noise = ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, sd_vid_gen_params->video_frames); ggml_tensor_set_f32_randn(noise, sd_ctx->sd->rng); - LOG_INFO("sampling using %s method", sampling_methods_str[sample_method]); + LOG_INFO("sampling using %s method", sampling_methods_str[sd_vid_gen_params->sample_method]); struct ggml_tensor* x_0 = sd_ctx->sd->sample(work_ctx, x_t, noise, cond, uncond, {}, + {}, 0.f, - min_cfg, - cfg_scale, + sd_vid_gen_params->guidance, 0.f, - sample_method, + sd_vid_gen_params->sample_method, sigmas, -1, SDCondition(NULL, NULL, NULL)); @@ -1673,13 +2180,13 @@ SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, return NULL; } - sd_image_t* result_images = (sd_image_t*)calloc(video_frames, sizeof(sd_image_t)); + sd_image_t* result_images = (sd_image_t*)calloc(sd_vid_gen_params->video_frames, sizeof(sd_image_t)); if (result_images == NULL) { ggml_free(work_ctx); return NULL; } - for (size_t i = 0; i < video_frames; i++) { + for (size_t i = 0; i < sd_vid_gen_params->video_frames; i++) { auto img_i = ggml_view_3d(work_ctx, img, img->ne[0], img->ne[1], img->ne[2], img->nb[1], img->nb[2], img->nb[3] * i); result_images[i].width = width; diff --git a/stable-diffusion.h b/stable-diffusion.h index 812e8fc94..a60325923 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -30,7 +30,8 @@ extern "C" { enum rng_type_t { STD_DEFAULT_RNG, - CUDA_RNG + CUDA_RNG, + RNG_TYPE_COUNT }; enum sample_method_t { @@ -44,7 +45,9 @@ enum sample_method_t { IPNDM, IPNDM_V, LCM, - N_SAMPLE_METHODS + DDIM_TRAILING, + TCD, + SAMPLE_METHOD_COUNT }; enum schedule_t { @@ -54,7 +57,7 @@ enum schedule_t { EXPONENTIAL, AYS, GITS, - N_SCHEDULES + SCHEDULE_COUNT }; // same as enum ggml_type @@ -65,39 +68,42 @@ enum sd_type_t { SD_TYPE_Q4_1 = 3, // SD_TYPE_Q4_2 = 4, support has been removed // SD_TYPE_Q4_3 = 5, support has been removed - SD_TYPE_Q5_0 = 6, - SD_TYPE_Q5_1 = 7, - SD_TYPE_Q8_0 = 8, - SD_TYPE_Q8_1 = 9, - SD_TYPE_Q2_K = 10, - SD_TYPE_Q3_K = 11, - SD_TYPE_Q4_K = 12, - SD_TYPE_Q5_K = 13, - SD_TYPE_Q6_K = 14, - SD_TYPE_Q8_K = 15, - SD_TYPE_IQ2_XXS = 16, - SD_TYPE_IQ2_XS = 17, - SD_TYPE_IQ3_XXS = 18, - SD_TYPE_IQ1_S = 19, - SD_TYPE_IQ4_NL = 20, - SD_TYPE_IQ3_S = 21, - SD_TYPE_IQ2_S = 22, - SD_TYPE_IQ4_XS = 23, - SD_TYPE_I8 = 24, - SD_TYPE_I16 = 25, - SD_TYPE_I32 = 26, - SD_TYPE_I64 = 27, - SD_TYPE_F64 = 28, - SD_TYPE_IQ1_M = 29, - SD_TYPE_BF16 = 30, - SD_TYPE_Q4_0_4_4 = 31, - SD_TYPE_Q4_0_4_8 = 32, - SD_TYPE_Q4_0_8_8 = 33, - SD_TYPE_COUNT, + SD_TYPE_Q5_0 = 6, + SD_TYPE_Q5_1 = 7, + SD_TYPE_Q8_0 = 8, + SD_TYPE_Q8_1 = 9, + SD_TYPE_Q2_K = 10, + SD_TYPE_Q3_K = 11, + SD_TYPE_Q4_K = 12, + SD_TYPE_Q5_K = 13, + SD_TYPE_Q6_K = 14, + SD_TYPE_Q8_K = 15, + SD_TYPE_IQ2_XXS = 16, + SD_TYPE_IQ2_XS = 17, + SD_TYPE_IQ3_XXS = 18, + SD_TYPE_IQ1_S = 19, + SD_TYPE_IQ4_NL = 20, + SD_TYPE_IQ3_S = 21, + SD_TYPE_IQ2_S = 22, + SD_TYPE_IQ4_XS = 23, + SD_TYPE_I8 = 24, + SD_TYPE_I16 = 25, + SD_TYPE_I32 = 26, + SD_TYPE_I64 = 27, + SD_TYPE_F64 = 28, + SD_TYPE_IQ1_M = 29, + SD_TYPE_BF16 = 30, + // SD_TYPE_Q4_0_4_4 = 31, support has been removed from gguf files + // SD_TYPE_Q4_0_4_8 = 32, + // SD_TYPE_Q4_0_8_8 = 33, + SD_TYPE_TQ1_0 = 34, + SD_TYPE_TQ2_0 = 35, + // SD_TYPE_IQ4_NL_4_4 = 36, + // SD_TYPE_IQ4_NL_4_8 = 37, + // SD_TYPE_IQ4_NL_8_8 = 38, + SD_TYPE_COUNT = 39, }; -SD_API const char* sd_type_name(enum sd_type_t type); - enum sd_log_level_t { SD_LOG_DEBUG, SD_LOG_INFO, @@ -105,13 +111,33 @@ enum sd_log_level_t { SD_LOG_ERROR }; -typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data); -typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data); - -SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data); -SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); -SD_API int32_t get_num_physical_cores(); -SD_API const char* sd_get_system_info(); +typedef struct { + const char* model_path; + const char* clip_l_path; + const char* clip_g_path; + const char* t5xxl_path; + const char* diffusion_model_path; + const char* vae_path; + const char* taesd_path; + const char* control_net_path; + const char* lora_model_dir; + const char* embedding_dir; + const char* stacked_id_embed_dir; + bool vae_decode_only; + bool vae_tiling; + bool free_params_immediately; + int n_threads; + enum sd_type_t wtype; + enum rng_type_t rng_type; + enum schedule_t schedule; + bool keep_clip_on_cpu; + bool keep_control_net_on_cpu; + bool keep_vae_on_cpu; + bool diffusion_flash_attn; + bool chroma_use_dit_mask; + bool chroma_use_t5_mask; + int chroma_t5_mask_pad; +} sd_ctx_params_t; typedef struct { uint32_t width; @@ -120,95 +146,106 @@ typedef struct { uint8_t* data; } sd_image_t; +typedef struct { + int* layers; + size_t layer_count; + float layer_start; + float layer_end; + float scale; +} sd_slg_params_t; + +typedef struct { + float txt_cfg; + float img_cfg; + float min_cfg; + float distilled_guidance; + sd_slg_params_t slg; +} sd_guidance_params_t; + +typedef struct { + const char* prompt; + const char* negative_prompt; + int clip_skip; + sd_guidance_params_t guidance; + sd_image_t init_image; + sd_image_t* ref_images; + int ref_images_count; + sd_image_t mask_image; + int width; + int height; + enum sample_method_t sample_method; + int sample_steps; + float eta; + float strength; + int64_t seed; + int batch_count; + const sd_image_t* control_cond; + float control_strength; + float style_strength; + bool normalize_input; + const char* input_id_images_path; +} sd_img_gen_params_t; + +typedef struct { + sd_image_t init_image; + int width; + int height; + sd_guidance_params_t guidance; + enum sample_method_t sample_method; + int sample_steps; + float strength; + int64_t seed; + int video_frames; + int motion_bucket_id; + int fps; + float augmentation_level; +} sd_vid_gen_params_t; + typedef struct sd_ctx_t sd_ctx_t; -SD_API sd_ctx_t* new_sd_ctx(const char* model_path, - const char* clip_l_path, - const char* clip_g_path, - const char* t5xxl_path, - const char* diffusion_model_path, - const char* vae_path, - const char* taesd_path, - const char* control_net_path_c_str, - const char* lora_model_dir, - const char* embed_dir_c_str, - const char* stacked_id_embed_dir_c_str, - bool vae_decode_only, - bool vae_tiling, - bool free_params_immediately, - int n_threads, - enum sd_type_t wtype, - enum rng_type_t rng_type, - enum schedule_t s, - bool keep_clip_on_cpu, - bool keep_control_net_cpu, - bool keep_vae_on_cpu); +typedef void (*sd_log_cb_t)(enum sd_log_level_t level, const char* text, void* data); +typedef void (*sd_progress_cb_t)(int step, int steps, float time, void* data); + +SD_API void sd_set_log_callback(sd_log_cb_t sd_log_cb, void* data); +SD_API void sd_set_progress_callback(sd_progress_cb_t cb, void* data); +SD_API int32_t get_num_physical_cores(); +SD_API const char* sd_get_system_info(); + +SD_API const char* sd_type_name(enum sd_type_t type); +SD_API enum sd_type_t str_to_sd_type(const char* str); +SD_API const char* sd_rng_type_name(enum rng_type_t rng_type); +SD_API enum rng_type_t str_to_rng_type(const char* str); +SD_API const char* sd_sample_method_name(enum sample_method_t sample_method); +SD_API enum sample_method_t str_to_sample_method(const char* str); +SD_API const char* sd_schedule_name(enum schedule_t schedule); +SD_API enum schedule_t str_to_schedule(const char* str); + +SD_API void sd_ctx_params_init(sd_ctx_params_t* sd_ctx_params); +SD_API char* sd_ctx_params_to_str(const sd_ctx_params_t* sd_ctx_params); +SD_API sd_ctx_t* new_sd_ctx(const sd_ctx_params_t* sd_ctx_params); SD_API void free_sd_ctx(sd_ctx_t* sd_ctx); -SD_API sd_image_t* txt2img(sd_ctx_t* sd_ctx, - const char* prompt, - const char* negative_prompt, - int clip_skip, - float cfg_scale, - float guidance, - int width, - int height, - enum sample_method_t sample_method, - int sample_steps, - int64_t seed, - int batch_count, - const sd_image_t* control_cond, - float control_strength, - float style_strength, - bool normalize_input, - const char* input_id_images_path); - -SD_API sd_image_t* img2img(sd_ctx_t* sd_ctx, - sd_image_t init_image, - const char* prompt, - const char* negative_prompt, - int clip_skip, - float cfg_scale, - float guidance, - int width, - int height, - enum sample_method_t sample_method, - int sample_steps, - float strength, - int64_t seed, - int batch_count, - const sd_image_t* control_cond, - float control_strength, - float style_strength, - bool normalize_input, - const char* input_id_images_path); - -SD_API sd_image_t* img2vid(sd_ctx_t* sd_ctx, - sd_image_t init_image, - int width, - int height, - int video_frames, - int motion_bucket_id, - int fps, - float augmentation_level, - float min_cfg, - float cfg_scale, - enum sample_method_t sample_method, - int sample_steps, - float strength, - int64_t seed); +SD_API void sd_img_gen_params_init(sd_img_gen_params_t* sd_img_gen_params); +SD_API char* sd_img_gen_params_to_str(const sd_img_gen_params_t* sd_img_gen_params); +SD_API sd_image_t* generate_image(sd_ctx_t* sd_ctx, const sd_img_gen_params_t* sd_img_gen_params); + +SD_API void sd_vid_gen_params_init(sd_vid_gen_params_t* sd_vid_gen_params); +SD_API sd_image_t* generate_video(sd_ctx_t* sd_ctx, const sd_vid_gen_params_t* sd_vid_gen_params); // broken typedef struct upscaler_ctx_t upscaler_ctx_t; SD_API upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path, - int n_threads, - enum sd_type_t wtype); + int n_threads); SD_API void free_upscaler_ctx(upscaler_ctx_t* upscaler_ctx); SD_API sd_image_t upscale(upscaler_ctx_t* upscaler_ctx, sd_image_t input_image, uint32_t upscale_factor); -SD_API bool convert(const char* input_path, const char* vae_path, const char* output_path, enum sd_type_t output_type); +SD_API bool convert(const char* input_path, + const char* vae_path, + const char* output_path, + enum sd_type_t output_type, + const char* tensor_type_rules); SD_API uint8_t* preprocess_canny(uint8_t* img, int width, diff --git a/t5.hpp b/t5.hpp index 79109e34b..d511ef24b 100644 --- a/t5.hpp +++ b/t5.hpp @@ -385,6 +385,7 @@ class T5UniGramTokenizer { void pad_tokens(std::vector& tokens, std::vector& weights, + std::vector* attention_mask, size_t max_length = 0, bool padding = false) { if (max_length > 0 && padding) { @@ -397,11 +398,15 @@ class T5UniGramTokenizer { LOG_DEBUG("token length: %llu", length); std::vector new_tokens; std::vector new_weights; + std::vector new_attention_mask; int token_idx = 0; for (int i = 0; i < length; i++) { if (token_idx >= orig_token_num) { break; } + if (attention_mask != nullptr) { + new_attention_mask.push_back(0.0); + } if (i % max_length == max_length - 1) { new_tokens.push_back(eos_id_); new_weights.push_back(1.0); @@ -414,13 +419,24 @@ class T5UniGramTokenizer { new_tokens.push_back(eos_id_); new_weights.push_back(1.0); + if (attention_mask != nullptr) { + new_attention_mask.push_back(0.0); + } + tokens = new_tokens; weights = new_weights; + if (attention_mask != nullptr) { + *attention_mask = new_attention_mask; + } if (padding) { int pad_token_id = pad_id_; tokens.insert(tokens.end(), length - tokens.size(), pad_token_id); weights.insert(weights.end(), length - weights.size(), 1.0); + if (attention_mask != nullptr) { + // maybe keep some padding tokens unmasked? + attention_mask->insert(attention_mask->end(), length - attention_mask->size(), -HUGE_VALF); + } } } } @@ -441,8 +457,9 @@ class T5LayerNorm : public UnaryBlock { int64_t hidden_size; float eps; - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = GGML_TYPE_F32; //(tensor_types.find(prefix + "weight") != tensor_types.end()) ? tensor_types[prefix + "weight"] : GGML_TYPE_F32; + params["weight"] = ggml_new_tensor_1d(ctx, wtype, hidden_size); } public: @@ -578,6 +595,7 @@ class T5Attention : public GGMLBlock { } if (past_bias != NULL) { if (mask != NULL) { + mask = ggml_repeat(ctx, mask, past_bias); mask = ggml_add(ctx, mask, past_bias); } else { mask = past_bias; @@ -717,14 +735,15 @@ struct T5Runner : public GGMLRunner { std::vector relative_position_bucket_vec; T5Runner(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, + const std::string prefix, int64_t num_layers = 24, int64_t model_dim = 4096, int64_t ff_dim = 10240, int64_t num_heads = 64, int64_t vocab_size = 32128) - : GGMLRunner(backend, wtype), model(num_layers, model_dim, ff_dim, num_heads, vocab_size) { - model.init(params_ctx, wtype); + : GGMLRunner(backend), model(num_layers, model_dim, ff_dim, num_heads, vocab_size) { + model.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -737,15 +756,17 @@ struct T5Runner : public GGMLRunner { struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* input_ids, - struct ggml_tensor* relative_position_bucket) { + struct ggml_tensor* relative_position_bucket, + struct ggml_tensor* attention_mask = NULL) { size_t N = input_ids->ne[1]; size_t n_token = input_ids->ne[0]; - auto hidden_states = model.forward(ctx, input_ids, NULL, NULL, relative_position_bucket); // [N, n_token, model_dim] + auto hidden_states = model.forward(ctx, input_ids, NULL, attention_mask, relative_position_bucket); // [N, n_token, model_dim] return hidden_states; } - struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids) { + struct ggml_cgraph* build_graph(struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask = NULL) { struct ggml_cgraph* gf = ggml_new_graph(compute_ctx); input_ids = to_backend(input_ids); @@ -765,7 +786,7 @@ struct T5Runner : public GGMLRunner { input_ids->ne[0]); set_backend_tensor_data(relative_position_bucket, relative_position_bucket_vec.data()); - struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket); + struct ggml_tensor* hidden_states = forward(compute_ctx, input_ids, relative_position_bucket, attention_mask); ggml_build_forward_expand(gf, hidden_states); @@ -774,10 +795,11 @@ struct T5Runner : public GGMLRunner { void compute(const int n_threads, struct ggml_tensor* input_ids, + struct ggml_tensor* attention_mask, ggml_tensor** output, ggml_context* output_ctx = NULL) { auto get_graph = [&]() -> struct ggml_cgraph* { - return build_graph(input_ids); + return build_graph(input_ids, attention_mask); }; GGMLRunner::compute(get_graph, n_threads, true, output, output_ctx); } @@ -854,14 +876,17 @@ struct T5Embedder { T5UniGramTokenizer tokenizer; T5Runner model; + static std::map empty_tensor_types; + T5Embedder(ggml_backend_t backend, - ggml_type wtype, - int64_t num_layers = 24, - int64_t model_dim = 4096, - int64_t ff_dim = 10240, - int64_t num_heads = 64, - int64_t vocab_size = 32128) - : model(backend, wtype, num_layers, model_dim, ff_dim, num_heads, vocab_size) { + std::map& tensor_types = empty_tensor_types, + const std::string prefix = "", + int64_t num_layers = 24, + int64_t model_dim = 4096, + int64_t ff_dim = 10240, + int64_t num_heads = 64, + int64_t vocab_size = 32128) + : model(backend, tensor_types, prefix, num_layers, model_dim, ff_dim, num_heads, vocab_size) { } void get_param_tensors(std::map& tensors, const std::string prefix) { @@ -872,9 +897,9 @@ struct T5Embedder { model.alloc_params_buffer(); } - std::pair, std::vector> tokenize(std::string text, - size_t max_length = 0, - bool padding = false) { + std::tuple, std::vector, std::vector> tokenize(std::string text, + size_t max_length = 0, + bool padding = false) { auto parsed_attention = parse_prompt_attention(text); { @@ -901,14 +926,16 @@ struct T5Embedder { tokens.push_back(EOS_TOKEN_ID); weights.push_back(1.0); - tokenizer.pad_tokens(tokens, weights, max_length, padding); + std::vector attention_mask; + + tokenizer.pad_tokens(tokens, weights, &attention_mask, max_length, padding); // for (int i = 0; i < tokens.size(); i++) { // std::cout << tokens[i] << ":" << weights[i] << ", "; // } // std::cout << std::endl; - return {tokens, weights}; + return {tokens, weights, attention_mask}; } void test() { @@ -929,8 +956,8 @@ struct T5Embedder { // TODO: fix cuda nan std::string text("a lovely cat"); auto tokens_and_weights = tokenize(text, 77, true); - std::vector& tokens = tokens_and_weights.first; - std::vector& weights = tokens_and_weights.second; + std::vector& tokens = std::get<0>(tokens_and_weights); + std::vector& weights = std::get<1>(tokens_and_weights); for (auto token : tokens) { printf("%d ", token); } @@ -939,7 +966,7 @@ struct T5Embedder { struct ggml_tensor* out = NULL; int t0 = ggml_time_ms(); - model.compute(8, input_ids, &out, work_ctx); + model.compute(8, input_ids, NULL, &out, work_ctx); int t1 = ggml_time_ms(); print_ggml_tensor(out); @@ -951,7 +978,7 @@ struct T5Embedder { // ggml_backend_t backend = ggml_backend_cuda_init(0); ggml_backend_t backend = ggml_backend_cpu_init(); ggml_type model_data_type = GGML_TYPE_F32; - std::shared_ptr t5 = std::shared_ptr(new T5Embedder(backend, model_data_type)); + std::shared_ptr t5 = std::shared_ptr(new T5Embedder(backend)); { LOG_INFO("loading from '%s'", file_path.c_str()); diff --git a/tae.hpp b/tae.hpp index 0e03b884e..678c44c57 100644 --- a/tae.hpp +++ b/tae.hpp @@ -62,7 +62,8 @@ class TinyEncoder : public UnaryBlock { int num_blocks = 3; public: - TinyEncoder() { + TinyEncoder(int z_channels = 4) + : z_channels(z_channels) { int index = 0; blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(in_channels, channels, {3, 3}, {1, 1}, {1, 1})); blocks[std::to_string(index++)] = std::shared_ptr(new TAEBlock(channels, channels)); @@ -106,7 +107,10 @@ class TinyDecoder : public UnaryBlock { int num_blocks = 3; public: - TinyDecoder(int index = 0) { + TinyDecoder(int z_channels = 4) + : z_channels(z_channels) { + int index = 0; + blocks[std::to_string(index++)] = std::shared_ptr(new Conv2d(z_channels, channels, {3, 3}, {1, 1}, {1, 1})); index++; // nn.ReLU() @@ -145,7 +149,7 @@ class TinyDecoder : public UnaryBlock { if (i == 1) { h = ggml_relu_inplace(ctx, h); } else { - h = ggml_upscale(ctx, h, 2); + h = ggml_upscale(ctx, h, 2, GGML_SCALE_MODE_NEAREST); } continue; } @@ -163,12 +167,16 @@ class TAESD : public GGMLBlock { bool decode_only; public: - TAESD(bool decode_only = true) + TAESD(bool decode_only = true, SDVersion version = VERSION_SD1) : decode_only(decode_only) { - blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder()); + int z_channels = 4; + if (sd_version_is_dit(version)) { + z_channels = 16; + } + blocks["decoder.layers"] = std::shared_ptr(new TinyDecoder(z_channels)); if (!decode_only) { - blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder()); + blocks["encoder.layers"] = std::shared_ptr(new TinyEncoder(z_channels)); } } @@ -188,12 +196,14 @@ struct TinyAutoEncoder : public GGMLRunner { bool decode_only = false; TinyAutoEncoder(ggml_backend_t backend, - ggml_type wtype, - bool decoder_only = true) + std::map& tensor_types, + const std::string prefix, + bool decoder_only = true, + SDVersion version = VERSION_SD1) : decode_only(decoder_only), - taesd(decode_only), - GGMLRunner(backend, wtype) { - taesd.init(params_ctx, wtype); + taesd(decoder_only, version), + GGMLRunner(backend) { + taesd.init(params_ctx, tensor_types, prefix); } std::string get_desc() { diff --git a/thirdparty/httplib.h b/thirdparty/httplib.h new file mode 100644 index 000000000..f360bd93e --- /dev/null +++ b/thirdparty/httplib.h @@ -0,0 +1,9465 @@ +// +// httplib.h +// +// Copyright (c) 2024 Yuji Hirose. All rights reserved. +// MIT License +// + +#ifndef CPPHTTPLIB_HTTPLIB_H +#define CPPHTTPLIB_HTTPLIB_H + +#define CPPHTTPLIB_VERSION "0.15.3" + +/* + * Configuration + */ + +#ifndef CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND +#define CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_KEEPALIVE_MAX_COUNT +#define CPPHTTPLIB_KEEPALIVE_MAX_COUNT 5 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND +#define CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_READ_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_READ_TIMEOUT_USECOND +#define CPPHTTPLIB_READ_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_WRITE_TIMEOUT_SECOND +#define CPPHTTPLIB_WRITE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_WRITE_TIMEOUT_USECOND +#define CPPHTTPLIB_WRITE_TIMEOUT_USECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND +#define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 +#endif + +#ifndef CPPHTTPLIB_IDLE_INTERVAL_USECOND +#ifdef _WIN32 +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 10000 +#else +#define CPPHTTPLIB_IDLE_INTERVAL_USECOND 0 +#endif +#endif + +#ifndef CPPHTTPLIB_REQUEST_URI_MAX_LENGTH +#define CPPHTTPLIB_REQUEST_URI_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_HEADER_MAX_LENGTH +#define CPPHTTPLIB_HEADER_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_REDIRECT_MAX_COUNT +#define CPPHTTPLIB_REDIRECT_MAX_COUNT 20 +#endif + +#ifndef CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT +#define CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#endif + +#ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH +#define CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH 8192 +#endif + +#ifndef CPPHTTPLIB_RANGE_MAX_COUNT +#define CPPHTTPLIB_RANGE_MAX_COUNT 1024 +#endif + +#ifndef CPPHTTPLIB_TCP_NODELAY +#define CPPHTTPLIB_TCP_NODELAY false +#endif + +#ifndef CPPHTTPLIB_RECV_BUFSIZ +#define CPPHTTPLIB_RECV_BUFSIZ size_t(4096u) +#endif + +#ifndef CPPHTTPLIB_COMPRESSION_BUFSIZ +#define CPPHTTPLIB_COMPRESSION_BUFSIZ size_t(16384u) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_COUNT +#define CPPHTTPLIB_THREAD_POOL_COUNT \ + ((std::max)(8u, std::thread::hardware_concurrency() > 0 \ + ? std::thread::hardware_concurrency() - 1 \ + : 0)) +#endif + +#ifndef CPPHTTPLIB_RECV_FLAGS +#define CPPHTTPLIB_RECV_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_SEND_FLAGS +#define CPPHTTPLIB_SEND_FLAGS 0 +#endif + +#ifndef CPPHTTPLIB_LISTEN_BACKLOG +#define CPPHTTPLIB_LISTEN_BACKLOG 5 +#endif + +/* + * Headers + */ + +#ifdef _WIN32 +#ifndef _CRT_SECURE_NO_WARNINGS +#define _CRT_SECURE_NO_WARNINGS +#endif //_CRT_SECURE_NO_WARNINGS + +#ifndef _CRT_NONSTDC_NO_DEPRECATE +#define _CRT_NONSTDC_NO_DEPRECATE +#endif //_CRT_NONSTDC_NO_DEPRECATE + +#if defined(_MSC_VER) +#if _MSC_VER < 1900 +#error Sorry, Visual Studio versions prior to 2015 are not supported +#endif + +#pragma comment(lib, "ws2_32.lib") + +#ifdef _WIN64 +using ssize_t = __int64; +#else +using ssize_t = long; +#endif +#endif // _MSC_VER + +#ifndef S_ISREG +#define S_ISREG(m) (((m)&S_IFREG) == S_IFREG) +#endif // S_ISREG + +#ifndef S_ISDIR +#define S_ISDIR(m) (((m)&S_IFDIR) == S_IFDIR) +#endif // S_ISDIR + +#ifndef NOMINMAX +#define NOMINMAX +#endif // NOMINMAX + +#include +#include +#include + +#ifndef WSA_FLAG_NO_HANDLE_INHERIT +#define WSA_FLAG_NO_HANDLE_INHERIT 0x80 +#endif + +using socket_t = SOCKET; +#ifdef CPPHTTPLIB_USE_POLL +#define poll(fds, nfds, timeout) WSAPoll(fds, nfds, timeout) +#endif + +#else // not _WIN32 + +#include +#if !defined(_AIX) && !defined(__MVS__) +#include +#endif +#ifdef __MVS__ +#include +#ifndef NI_MAXHOST +#define NI_MAXHOST 1025 +#endif +#endif +#include +#include +#include +#ifdef __linux__ +#include +#endif +#include +#ifdef CPPHTTPLIB_USE_POLL +#include +#endif +#include +#include +#include +#include +#include +#include +#include + +using socket_t = int; +#ifndef INVALID_SOCKET +#define INVALID_SOCKET (-1) +#endif +#endif //_WIN32 + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +#include + +// these are defined in wincrypt.h and it breaks compilation if BoringSSL is +// used +#undef X509_NAME +#undef X509_CERT_PAIR +#undef X509_EXTENSIONS +#undef PKCS7_SIGNER_INFO + +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#include +#if TARGET_OS_OSX +#include +#include +#endif // TARGET_OS_OSX +#endif // _WIN32 + +#include +#include +#include +#include + +#if defined(_WIN32) && defined(OPENSSL_USE_APPLINK) +#include +#endif + +#include +#include + +#if OPENSSL_VERSION_NUMBER < 0x30000000L +#error Sorry, OpenSSL versions prior to 3.0.0 are not supported +#endif + +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +#include +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +#include +#include +#endif + +/* + * Declaration + */ +namespace httplib { + +namespace detail { + +/* + * Backport std::make_unique from C++14. + * + * NOTE: This code came up with the following stackoverflow post: + * https://stackoverflow.com/questions/10149840/c-arrays-and-make-unique + * + */ + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(Args &&...args) { + return std::unique_ptr(new T(std::forward(args)...)); +} + +template +typename std::enable_if::value, std::unique_ptr>::type +make_unique(std::size_t n) { + typedef typename std::remove_extent::type RT; + return std::unique_ptr(new RT[n]); +} + +struct ci { + bool operator()(const std::string &s1, const std::string &s2) const { + return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), + s2.end(), + [](unsigned char c1, unsigned char c2) { + return ::tolower(c1) < ::tolower(c2); + }); + } +}; + +// This is based on +// "http://www.open-std.org/jtc1/sc22/wg21/docs/papers/2014/n4189". + +struct scope_exit { + explicit scope_exit(std::function &&f) + : exit_function(std::move(f)), execute_on_destruction{true} {} + + scope_exit(scope_exit &&rhs) noexcept + : exit_function(std::move(rhs.exit_function)), + execute_on_destruction{rhs.execute_on_destruction} { + rhs.release(); + } + + ~scope_exit() { + if (execute_on_destruction) { this->exit_function(); } + } + + void release() { this->execute_on_destruction = false; } + +private: + scope_exit(const scope_exit &) = delete; + void operator=(const scope_exit &) = delete; + scope_exit &operator=(scope_exit &&) = delete; + + std::function exit_function; + bool execute_on_destruction; +}; + +} // namespace detail + +enum StatusCode { + // Information responses + Continue_100 = 100, + SwitchingProtocol_101 = 101, + Processing_102 = 102, + EarlyHints_103 = 103, + + // Successful responses + OK_200 = 200, + Created_201 = 201, + Accepted_202 = 202, + NonAuthoritativeInformation_203 = 203, + NoContent_204 = 204, + ResetContent_205 = 205, + PartialContent_206 = 206, + MultiStatus_207 = 207, + AlreadyReported_208 = 208, + IMUsed_226 = 226, + + // Redirection messages + MultipleChoices_300 = 300, + MovedPermanently_301 = 301, + Found_302 = 302, + SeeOther_303 = 303, + NotModified_304 = 304, + UseProxy_305 = 305, + unused_306 = 306, + TemporaryRedirect_307 = 307, + PermanentRedirect_308 = 308, + + // Client error responses + BadRequest_400 = 400, + Unauthorized_401 = 401, + PaymentRequired_402 = 402, + Forbidden_403 = 403, + NotFound_404 = 404, + MethodNotAllowed_405 = 405, + NotAcceptable_406 = 406, + ProxyAuthenticationRequired_407 = 407, + RequestTimeout_408 = 408, + Conflict_409 = 409, + Gone_410 = 410, + LengthRequired_411 = 411, + PreconditionFailed_412 = 412, + PayloadTooLarge_413 = 413, + UriTooLong_414 = 414, + UnsupportedMediaType_415 = 415, + RangeNotSatisfiable_416 = 416, + ExpectationFailed_417 = 417, + ImATeapot_418 = 418, + MisdirectedRequest_421 = 421, + UnprocessableContent_422 = 422, + Locked_423 = 423, + FailedDependency_424 = 424, + TooEarly_425 = 425, + UpgradeRequired_426 = 426, + PreconditionRequired_428 = 428, + TooManyRequests_429 = 429, + RequestHeaderFieldsTooLarge_431 = 431, + UnavailableForLegalReasons_451 = 451, + + // Server error responses + InternalServerError_500 = 500, + NotImplemented_501 = 501, + BadGateway_502 = 502, + ServiceUnavailable_503 = 503, + GatewayTimeout_504 = 504, + HttpVersionNotSupported_505 = 505, + VariantAlsoNegotiates_506 = 506, + InsufficientStorage_507 = 507, + LoopDetected_508 = 508, + NotExtended_510 = 510, + NetworkAuthenticationRequired_511 = 511, +}; + +using Headers = std::multimap; + +using Params = std::multimap; +using Match = std::smatch; + +using Progress = std::function; + +struct Response; +using ResponseHandler = std::function; + +struct MultipartFormData { + std::string name; + std::string content; + std::string filename; + std::string content_type; +}; +using MultipartFormDataItems = std::vector; +using MultipartFormDataMap = std::multimap; + +class DataSink { +public: + DataSink() : os(&sb_), sb_(*this) {} + + DataSink(const DataSink &) = delete; + DataSink &operator=(const DataSink &) = delete; + DataSink(DataSink &&) = delete; + DataSink &operator=(DataSink &&) = delete; + + std::function write; + std::function is_writable; + std::function done; + std::function done_with_trailer; + std::ostream os; + +private: + class data_sink_streambuf final : public std::streambuf { + public: + explicit data_sink_streambuf(DataSink &sink) : sink_(sink) {} + + protected: + std::streamsize xsputn(const char *s, std::streamsize n) override { + sink_.write(s, static_cast(n)); + return n; + } + + private: + DataSink &sink_; + }; + + data_sink_streambuf sb_; +}; + +using ContentProvider = + std::function; + +using ContentProviderWithoutLength = + std::function; + +using ContentProviderResourceReleaser = std::function; + +struct MultipartFormDataProvider { + std::string name; + ContentProviderWithoutLength provider; + std::string filename; + std::string content_type; +}; +using MultipartFormDataProviderItems = std::vector; + +using ContentReceiverWithProgress = + std::function; + +using ContentReceiver = + std::function; + +using MultipartContentHeader = + std::function; + +class ContentReader { +public: + using Reader = std::function; + using MultipartReader = std::function; + + ContentReader(Reader reader, MultipartReader multipart_reader) + : reader_(std::move(reader)), + multipart_reader_(std::move(multipart_reader)) {} + + bool operator()(MultipartContentHeader header, + ContentReceiver receiver) const { + return multipart_reader_(std::move(header), std::move(receiver)); + } + + bool operator()(ContentReceiver receiver) const { + return reader_(std::move(receiver)); + } + + Reader reader_; + MultipartReader multipart_reader_; +}; + +using Range = std::pair; +using Ranges = std::vector; + +struct Request { + std::string method; + std::string path; + Headers headers; + std::string body; + + std::string remote_addr; + int remote_port = -1; + std::string local_addr; + int local_port = -1; + + // for server + std::string version; + std::string target; + Params params; + MultipartFormDataMap files; + Ranges ranges; + Match matches; + std::unordered_map path_params; + + // for client + ResponseHandler response_handler; + ContentReceiverWithProgress content_receiver; + Progress progress; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + const SSL *ssl = nullptr; +#endif + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + bool has_param(const std::string &key) const; + std::string get_param_value(const std::string &key, size_t id = 0) const; + size_t get_param_value_count(const std::string &key) const; + + bool is_multipart_form_data() const; + + bool has_file(const std::string &key) const; + MultipartFormData get_file_value(const std::string &key) const; + std::vector get_file_values(const std::string &key) const; + + // private members... + size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; + size_t content_length_ = 0; + ContentProvider content_provider_; + bool is_chunked_content_provider_ = false; + size_t authorization_count_ = 0; +}; + +struct Response { + std::string version; + int status = -1; + std::string reason; + Headers headers; + std::string body; + std::string location; // Redirect location + + bool has_header(const std::string &key) const; + std::string get_header_value(const std::string &key, size_t id = 0) const; + uint64_t get_header_value_u64(const std::string &key, size_t id = 0) const; + size_t get_header_value_count(const std::string &key) const; + void set_header(const std::string &key, const std::string &val); + + void set_redirect(const std::string &url, int status = StatusCode::Found_302); + void set_content(const char *s, size_t n, const std::string &content_type); + void set_content(const std::string &s, const std::string &content_type); + void set_content(std::string &&s, const std::string &content_type); + + void set_content_provider( + size_t length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + void set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser = nullptr); + + Response() = default; + Response(const Response &) = default; + Response &operator=(const Response &) = default; + Response(Response &&) = default; + Response &operator=(Response &&) = default; + ~Response() { + if (content_provider_resource_releaser_) { + content_provider_resource_releaser_(content_provider_success_); + } + } + + // private members... + size_t content_length_ = 0; + ContentProvider content_provider_; + ContentProviderResourceReleaser content_provider_resource_releaser_; + bool is_chunked_content_provider_ = false; + bool content_provider_success_ = false; +}; + +class Stream { +public: + virtual ~Stream() = default; + + virtual bool is_readable() const = 0; + virtual bool is_writable() const = 0; + + virtual ssize_t read(char *ptr, size_t size) = 0; + virtual ssize_t write(const char *ptr, size_t size) = 0; + virtual void get_remote_ip_and_port(std::string &ip, int &port) const = 0; + virtual void get_local_ip_and_port(std::string &ip, int &port) const = 0; + virtual socket_t socket() const = 0; + + template + ssize_t write_format(const char *fmt, const Args &...args); + ssize_t write(const char *ptr); + ssize_t write(const std::string &s); +}; + +class TaskQueue { +public: + TaskQueue() = default; + virtual ~TaskQueue() = default; + + virtual bool enqueue(std::function fn) = 0; + virtual void shutdown() = 0; + + virtual void on_idle() {} +}; + +class ThreadPool final : public TaskQueue { +public: + explicit ThreadPool(size_t n, size_t mqr = 0) + : shutdown_(false), max_queued_requests_(mqr) { + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } + } + + ThreadPool(const ThreadPool &) = delete; + ~ThreadPool() override = default; + + bool enqueue(std::function fn) override { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; + } + + void shutdown() override { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } + } + +private: + struct worker { + explicit worker(ThreadPool &pool) : pool_(pool) {} + + void operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait( + lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = std::move(pool_.jobs_.front()); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + } + + ThreadPool &pool_; + }; + friend struct worker; + + std::vector threads_; + std::list> jobs_; + + bool shutdown_; + size_t max_queued_requests_ = 0; + + std::condition_variable cond_; + std::mutex mutex_; +}; + +using Logger = std::function; + +using SocketOptions = std::function; + +void default_socket_options(socket_t sock); + +const char *status_message(int status); + +std::string get_bearer_token_auth(const Request &req); + +namespace detail { + +class MatcherBase { +public: + virtual ~MatcherBase() = default; + + // Match request path and populate its matches and + virtual bool match(Request &request) const = 0; +}; + +/** + * Captures parameters in request path and stores them in Request::path_params + * + * Capture name is a substring of a pattern from : to /. + * The rest of the pattern is matched agains the request path directly + * Parameters are captured starting from the next character after + * the end of the last matched static pattern fragment until the next /. + * + * Example pattern: + * "/path/fragments/:capture/more/fragments/:second_capture" + * Static fragments: + * "/path/fragments/", "more/fragments/" + * + * Given the following request path: + * "/path/fragments/:1/more/fragments/:2" + * the resulting capture will be + * {{"capture", "1"}, {"second_capture", "2"}} + */ +class PathParamsMatcher final : public MatcherBase { +public: + PathParamsMatcher(const std::string &pattern); + + bool match(Request &request) const override; + +private: + static constexpr char marker = ':'; + // Treat segment separators as the end of path parameter capture + // Does not need to handle query parameters as they are parsed before path + // matching + static constexpr char separator = '/'; + + // Contains static path fragments to match against, excluding the '/' after + // path params + // Fragments are separated by path params + std::vector static_fragments_; + // Stores the names of the path parameters to be used as keys in the + // Request::path_params map + std::vector param_names_; +}; + +/** + * Performs std::regex_match on request path + * and stores the result in Request::matches + * + * Note that regex match is performed directly on the whole request. + * This means that wildcard patterns may match multiple path segments with /: + * "/begin/(.*)/end" will match both "/begin/middle/end" and "/begin/1/2/end". + */ +class RegexMatcher final : public MatcherBase { +public: + RegexMatcher(const std::string &pattern) : regex_(pattern) {} + + bool match(Request &request) const override; + +private: + std::regex regex_; +}; + +ssize_t write_headers(Stream &strm, const Headers &headers); + +} // namespace detail + +class Server { +public: + using Handler = std::function; + + using ExceptionHandler = + std::function; + + enum class HandlerResponse { + Handled, + Unhandled, + }; + using HandlerWithResponse = + std::function; + + using HandlerWithContentReader = std::function; + + using Expect100ContinueHandler = + std::function; + + Server(); + + virtual ~Server(); + + virtual bool is_valid() const; + + Server &Get(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, Handler handler); + Server &Post(const std::string &pattern, HandlerWithContentReader handler); + Server &Put(const std::string &pattern, Handler handler); + Server &Put(const std::string &pattern, HandlerWithContentReader handler); + Server &Patch(const std::string &pattern, Handler handler); + Server &Patch(const std::string &pattern, HandlerWithContentReader handler); + Server &Delete(const std::string &pattern, Handler handler); + Server &Delete(const std::string &pattern, HandlerWithContentReader handler); + Server &Options(const std::string &pattern, Handler handler); + + bool set_base_dir(const std::string &dir, + const std::string &mount_point = std::string()); + bool set_mount_point(const std::string &mount_point, const std::string &dir, + Headers headers = Headers()); + bool remove_mount_point(const std::string &mount_point); + Server &set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime); + Server &set_default_file_mimetype(const std::string &mime); + Server &set_file_request_handler(Handler handler); + + Server &set_error_handler(HandlerWithResponse handler); + Server &set_error_handler(Handler handler); + Server &set_exception_handler(ExceptionHandler handler); + Server &set_pre_routing_handler(HandlerWithResponse handler); + Server &set_post_routing_handler(Handler handler); + + Server &set_expect_100_continue_handler(Expect100ContinueHandler handler); + Server &set_logger(Logger logger); + + Server &set_address_family(int family); + Server &set_tcp_nodelay(bool on); + Server &set_socket_options(SocketOptions socket_options); + + Server &set_default_headers(Headers headers); + Server & + set_header_writer(std::function const &writer); + + Server &set_keep_alive_max_count(size_t count); + Server &set_keep_alive_timeout(time_t sec); + + Server &set_read_timeout(time_t sec, time_t usec = 0); + template + Server &set_read_timeout(const std::chrono::duration &duration); + + Server &set_write_timeout(time_t sec, time_t usec = 0); + template + Server &set_write_timeout(const std::chrono::duration &duration); + + Server &set_idle_interval(time_t sec, time_t usec = 0); + template + Server &set_idle_interval(const std::chrono::duration &duration); + + Server &set_payload_max_length(size_t length); + + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); + int bind_to_any_port(const std::string &host, int socket_flags = 0); + bool listen_after_bind(); + + bool listen(const std::string &host, int port, int socket_flags = 0); + + bool is_running() const; + void wait_until_ready() const; + void stop(); + + std::function new_task_queue; + +protected: + bool process_request(Stream &strm, bool close_connection, + bool &connection_closed, + const std::function &setup_request); + + std::atomic svr_sock_{INVALID_SOCKET}; + size_t keep_alive_max_count_ = CPPHTTPLIB_KEEPALIVE_MAX_COUNT; + time_t keep_alive_timeout_sec_ = CPPHTTPLIB_KEEPALIVE_TIMEOUT_SECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; + time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + +private: + using Handlers = + std::vector, Handler>>; + using HandlersForContentReader = + std::vector, + HandlerWithContentReader>>; + + static std::unique_ptr + make_matcher(const std::string &pattern); + + socket_t create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const; + int bind_internal(const std::string &host, int port, int socket_flags); + bool listen_internal(); + + bool routing(Request &req, Response &res, Stream &strm); + bool handle_file_request(const Request &req, Response &res, + bool head = false); + bool dispatch_request(Request &req, Response &res, + const Handlers &handlers) const; + bool dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const; + + bool parse_request_line(const char *s, Request &req) const; + void apply_ranges(const Request &req, Response &res, + std::string &content_type, std::string &boundary) const; + bool write_response(Stream &strm, bool close_connection, Request &req, + Response &res); + bool write_response_with_content(Stream &strm, bool close_connection, + const Request &req, Response &res); + bool write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges); + bool write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type); + bool read_content(Stream &strm, Request &req, Response &res); + bool + read_content_with_content_receiver(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver); + bool read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const; + + virtual bool process_and_close_socket(socket_t sock); + + std::atomic is_running_{false}; + std::atomic done_{false}; + + struct MountPointEntry { + std::string mount_point; + std::string base_dir; + Headers headers; + }; + std::vector base_dirs_; + std::map file_extension_and_mimetype_map_; + std::string default_file_mimetype_ = "application/octet-stream"; + Handler file_request_handler_; + + Handlers get_handlers_; + Handlers post_handlers_; + HandlersForContentReader post_handlers_for_content_reader_; + Handlers put_handlers_; + HandlersForContentReader put_handlers_for_content_reader_; + Handlers patch_handlers_; + HandlersForContentReader patch_handlers_for_content_reader_; + Handlers delete_handlers_; + HandlersForContentReader delete_handlers_for_content_reader_; + Handlers options_handlers_; + + HandlerWithResponse error_handler_; + ExceptionHandler exception_handler_; + HandlerWithResponse pre_routing_handler_; + Handler post_routing_handler_; + Expect100ContinueHandler expect_100_continue_handler_; + + Logger logger_; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = default_socket_options; + + Headers default_headers_; + std::function header_writer_ = + detail::write_headers; +}; + +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + +class Result { +public: + Result() = default; + Result(std::unique_ptr &&res, Error err, + Headers &&request_headers = Headers{}) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)) {} + // Response + operator bool() const { return res_ != nullptr; } + bool operator==(std::nullptr_t) const { return res_ == nullptr; } + bool operator!=(std::nullptr_t) const { return res_ != nullptr; } + const Response &value() const { return *res_; } + Response &value() { return *res_; } + const Response &operator*() const { return *res_; } + Response &operator*() { return *res_; } + const Response *operator->() const { return res_.get(); } + Response *operator->() { return res_.get(); } + + // Error + Error error() const { return err_; } + + // Request Headers + bool has_request_header(const std::string &key) const; + std::string get_request_header_value(const std::string &key, + size_t id = 0) const; + uint64_t get_request_header_value_u64(const std::string &key, + size_t id = 0) const; + size_t get_request_header_value_count(const std::string &key) const; + +private: + std::unique_ptr res_; + Error err_ = Error::Unknown; + Headers request_headers_; +}; + +class ClientImpl { +public: + explicit ClientImpl(const std::string &host); + + explicit ClientImpl(const std::string &host, int port); + + explicit ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + virtual ~ClientImpl(); + + virtual bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, + Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Put(const std::string &path, size_t content_length, + ContentProvider content_provider, const std::string &content_type); + Result Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void set_ca_cert_store(X509_STORE *ca_cert_store); + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); +#endif + + void set_logger(Logger logger); + +protected: + struct Socket { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + }; + + virtual bool create_and_connect_socket(Socket &socket, Error &error); + + // All of: + // shutdown_ssl + // shutdown_socket + // close_socket + // should ONLY be called when socket_mutex_ is locked. + // Also, shutdown_ssl and close_socket should also NOT be called concurrently + // with a DIFFERENT thread sending requests using that socket. + virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); + void shutdown_socket(Socket &socket) const; + void close_socket(Socket &socket); + + bool process_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + + bool write_content_with_provider(Stream &strm, const Request &req, + Error &error) const; + + void copy_settings(const ClientImpl &rhs); + + // Socket endpoint information + const std::string host_; + const int port_; + const std::string host_and_port_; + + // Current open socket + Socket socket_; + mutable std::mutex socket_mutex_; + std::recursive_mutex request_mutex_; + + // These are all protected under socket_mutex + size_t socket_requests_in_flight_ = 0; + std::thread::id socket_requests_are_from_thread_ = std::thread::id(); + bool socket_should_be_closed_when_request_is_done_ = false; + + // Hostname-IP map + std::map addr_map_; + + // Default headers + Headers default_headers_; + + // Header writer + std::function header_writer_ = + detail::write_headers; + + // Settings + std::string client_cert_path_; + std::string client_key_path_; + + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + time_t read_timeout_sec_ = CPPHTTPLIB_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = CPPHTTPLIB_READ_TIMEOUT_USECOND; + time_t write_timeout_sec_ = CPPHTTPLIB_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_WRITE_TIMEOUT_USECOND; + + std::string basic_auth_username_; + std::string basic_auth_password_; + std::string bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string digest_auth_username_; + std::string digest_auth_password_; +#endif + + bool keep_alive_ = false; + bool follow_location_ = false; + + bool url_encode_ = true; + + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + SocketOptions socket_options_ = nullptr; + + bool compress_ = false; + bool decompress_ = true; + + std::string interface_; + + std::string proxy_host_; + int proxy_port_ = -1; + + std::string proxy_basic_auth_username_; + std::string proxy_basic_auth_password_; + std::string proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + + X509_STORE *ca_cert_store_ = nullptr; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool server_certificate_verification_ = true; +#endif + + Logger logger_; + +private: + bool send_(Request &req, Response &res, Error &error); + Result send_(Request &&req); + + socket_t create_client_socket(Error &error) const; + bool read_response_line(Stream &strm, const Request &req, + Response &res) const; + bool write_request(Stream &strm, Request &req, bool close_connection, + Error &error); + bool redirect(Request &req, Response &res, Error &error); + bool handle_request(Stream &strm, Request &req, Response &res, + bool close_connection, Error &error); + std::unique_ptr send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error); + Result send_with_content_provider( + const std::string &method, const std::string &path, + const Headers &headers, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type); + ContentProviderWithoutLength get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const; + + std::string adjust_host_string(const std::string &host) const; + + virtual bool process_socket(const Socket &socket, + std::function callback); + virtual bool is_ssl() const; +}; + +class Client { +public: + // Universal interface + explicit Client(const std::string &scheme_host_port); + + explicit Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path); + + // HTTP only interface + explicit Client(const std::string &host, int port); + + explicit Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path); + + Client(Client &&) = default; + + ~Client(); + + bool is_valid() const; + + Result Get(const std::string &path); + Result Get(const std::string &path, const Headers &headers); + Result Get(const std::string &path, Progress progress); + Result Get(const std::string &path, const Headers &headers, + Progress progress); + Result Get(const std::string &path, ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver); + Result Get(const std::string &path, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver); + Result Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, ContentReceiver content_receiver, + Progress progress); + Result Get(const std::string &path, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress); + + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ContentReceiver content_receiver, + Progress progress = nullptr); + Result Get(const std::string &path, const Params ¶ms, + const Headers &headers, ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress = nullptr); + + Result Head(const std::string &path); + Result Head(const std::string &path, const Headers &headers); + + Result Post(const std::string &path); + Result Post(const std::string &path, const Headers &headers); + Result Post(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Post(const std::string &path, const std::string &body, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Post(const std::string &path, const Params ¶ms); + Result Post(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Post(const std::string &path, const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Put(const std::string &path); + Result Put(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, const char *body, + size_t content_length, const std::string &content_type); + Result Put(const std::string &path, const std::string &body, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Put(const std::string &path, size_t content_length, + ContentProvider content_provider, const std::string &content_type); + Result Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Put(const std::string &path, const Params ¶ms); + Result Put(const std::string &path, const Headers &headers, + const Params ¶ms); + Result Put(const std::string &path, const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, const std::string &boundary); + Result Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items); + + Result Patch(const std::string &path); + Result Patch(const std::string &path, const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Patch(const std::string &path, const std::string &body, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + Result Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + size_t content_length, ContentProvider content_provider, + const std::string &content_type); + Result Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type); + + Result Delete(const std::string &path); + Result Delete(const std::string &path, const Headers &headers); + Result Delete(const std::string &path, const char *body, + size_t content_length, const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type); + Result Delete(const std::string &path, const std::string &body, + const std::string &content_type); + Result Delete(const std::string &path, const Headers &headers, + const std::string &body, const std::string &content_type); + + Result Options(const std::string &path); + Result Options(const std::string &path, const Headers &headers); + + bool send(Request &req, Response &res, Error &error); + Result send(const Request &req); + + void stop(); + + std::string host() const; + int port() const; + + size_t is_socket_open() const; + socket_t socket() const; + + void set_hostname_addr_map(std::map addr_map); + + void set_default_headers(Headers headers); + + void + set_header_writer(std::function const &writer); + + void set_address_family(int family); + void set_tcp_nodelay(bool on); + void set_socket_options(SocketOptions socket_options); + + void set_connection_timeout(time_t sec, time_t usec = 0); + template + void + set_connection_timeout(const std::chrono::duration &duration); + + void set_read_timeout(time_t sec, time_t usec = 0); + template + void set_read_timeout(const std::chrono::duration &duration); + + void set_write_timeout(time_t sec, time_t usec = 0); + template + void set_write_timeout(const std::chrono::duration &duration); + + void set_basic_auth(const std::string &username, const std::string &password); + void set_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_digest_auth(const std::string &username, + const std::string &password); +#endif + + void set_keep_alive(bool on); + void set_follow_location(bool on); + + void set_url_encode(bool on); + + void set_compress(bool on); + + void set_decompress(bool on); + + void set_interface(const std::string &intf); + + void set_proxy(const std::string &host, int port); + void set_proxy_basic_auth(const std::string &username, + const std::string &password); + void set_proxy_bearer_token_auth(const std::string &token); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_proxy_digest_auth(const std::string &username, + const std::string &password); +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void enable_server_certificate_verification(bool enabled); +#endif + + void set_logger(Logger logger); + + // SSL +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; +#endif + +private: + std::unique_ptr cli_; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + bool is_ssl_ = false; +#endif +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLServer : public Server { +public: + SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path = nullptr, + const char *client_ca_cert_dir_path = nullptr, + const char *private_key_password = nullptr); + + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + SSLServer( + const std::function &setup_ssl_ctx_callback); + + ~SSLServer() override; + + bool is_valid() const override; + + SSL_CTX *ssl_context() const; + +private: + bool process_and_close_socket(socket_t sock) override; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; +}; + +class SSLClient final : public ClientImpl { +public: + explicit SSLClient(const std::string &host); + + explicit SSLClient(const std::string &host, int port); + + explicit SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password = std::string()); + + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + ~SSLClient() override; + + bool is_valid() const override; + + void set_ca_cert_store(X509_STORE *ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + long get_openssl_verify_result() const; + + SSL_CTX *ssl_context() const; + +private: + bool create_and_connect_socket(Socket &socket, Error &error) override; + void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; + void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); + + bool process_socket(const Socket &socket, + std::function callback) override; + bool is_ssl() const override; + + bool connect_with_proxy(Socket &sock, Response &res, bool &success, + Error &error); + bool initialize_ssl(Socket &socket, Error &error); + + bool load_certs(); + + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; + bool check_host_name(const char *pattern, size_t pattern_len) const; + + SSL_CTX *ctx_; + std::mutex ctx_mutex_; + std::once_flag initialize_cert_; + + std::vector host_components_; + + long verify_result_ = 0; + + friend class ClientImpl; +}; +#endif + +/* + * Implementation of template methods. + */ + +namespace detail { + +template +inline void duration_to_sec_and_usec(const T &duration, U callback) { + auto sec = std::chrono::duration_cast(duration).count(); + auto usec = std::chrono::duration_cast( + duration - std::chrono::seconds(sec)) + .count(); + callback(static_cast(sec), static_cast(usec)); +} + +inline uint64_t get_header_value_u64(const Headers &headers, + const std::string &key, size_t id, + uint64_t def) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { + return std::strtoull(it->second.data(), nullptr, 10); + } + return def; +} + +} // namespace detail + +inline uint64_t Request::get_header_value_u64(const std::string &key, + size_t id) const { + return detail::get_header_value_u64(headers, key, id, 0); +} + +inline uint64_t Response::get_header_value_u64(const std::string &key, + size_t id) const { + return detail::get_header_value_u64(headers, key, id, 0); +} + +template +inline ssize_t Stream::write_format(const char *fmt, const Args &...args) { + const auto bufsiz = 2048; + std::array buf{}; + + auto sn = snprintf(buf.data(), buf.size() - 1, fmt, args...); + if (sn <= 0) { return sn; } + + auto n = static_cast(sn); + + if (n >= buf.size() - 1) { + std::vector glowable_buf(buf.size()); + + while (n >= glowable_buf.size() - 1) { + glowable_buf.resize(glowable_buf.size() * 2); + n = static_cast( + snprintf(&glowable_buf[0], glowable_buf.size() - 1, fmt, args...)); + } + return write(&glowable_buf[0], n); + } else { + return write(buf.data(), n); + } +} + +inline void default_socket_options(socket_t sock) { + int yes = 1; +#ifdef _WIN32 + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&yes), sizeof(yes)); + setsockopt(sock, SOL_SOCKET, SO_EXCLUSIVEADDRUSE, + reinterpret_cast(&yes), sizeof(yes)); +#else +#ifdef SO_REUSEPORT + setsockopt(sock, SOL_SOCKET, SO_REUSEPORT, + reinterpret_cast(&yes), sizeof(yes)); +#else + setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, + reinterpret_cast(&yes), sizeof(yes)); +#endif +#endif +} + +inline const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +inline std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + static std::string BearerHeaderPrefix = "Bearer "; + return req.get_header_value("Authorization") + .substr(BearerHeaderPrefix.length()); + } + return ""; +} + +template +inline Server & +Server::set_read_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_write_timeout(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); + return *this; +} + +template +inline Server & +Server::set_idle_interval(const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_idle_interval(sec, usec); }); + return *this; +} + +inline std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::Unknown: return "Unknown"; + default: break; + } + + return "Invalid"; +} + +inline std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + +inline uint64_t Result::get_request_header_value_u64(const std::string &key, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, id, 0); +} + +template +inline void ClientImpl::set_connection_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t usec) { + set_connection_timeout(sec, usec); + }); +} + +template +inline void ClientImpl::set_read_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_read_timeout(sec, usec); }); +} + +template +inline void ClientImpl::set_write_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec( + duration, [&](time_t sec, time_t usec) { set_write_timeout(sec, usec); }); +} + +template +inline void Client::set_connection_timeout( + const std::chrono::duration &duration) { + cli_->set_connection_timeout(duration); +} + +template +inline void +Client::set_read_timeout(const std::chrono::duration &duration) { + cli_->set_read_timeout(duration); +} + +template +inline void +Client::set_write_timeout(const std::chrono::duration &duration) { + cli_->set_write_timeout(duration); +} + +/* + * Forward declarations and types that will be part of the .h file if split into + * .h + .cc. + */ + +std::string hosted_at(const std::string &hostname); + +void hosted_at(const std::string &hostname, std::vector &addrs); + +std::string append_query_params(const std::string &path, const Params ¶ms); + +std::pair make_range_header(const Ranges &ranges); + +std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, + bool is_proxy = false); + +namespace detail { + +std::string encode_query_param(const std::string &value); + +std::string decode_url(const std::string &s, bool convert_plus_to_space); + +void read_file(const std::string &path, std::string &out); + +std::string trim_copy(const std::string &s); + +void split(const char *b, const char *e, char d, + std::function fn); + +void split(const char *b, const char *e, char d, size_t m, + std::function fn); + +bool process_client_socket(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, + std::function callback); + +socket_t create_client_socket( + const std::string &host, const std::string &ip, int port, + int address_family, bool tcp_nodelay, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error); + +const char *get_header_value(const Headers &headers, const std::string &key, + size_t id = 0, const char *def = nullptr); + +std::string params_to_query_str(const Params ¶ms); + +void parse_query_text(const std::string &s, Params ¶ms); + +bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary); + +bool parse_range_header(const std::string &s, Ranges &ranges); + +int close_socket(socket_t sock); + +ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); + +ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); + +enum class EncodingType { None = 0, Gzip, Brotli }; + +EncodingType encoding_type(const Request &req, const Response &res); + +class BufferStream final : public Stream { +public: + BufferStream() = default; + ~BufferStream() override = default; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + + const std::string &get_buffer() const; + +private: + std::string buffer; + size_t position = 0; +}; + +class compressor { +public: + virtual ~compressor() = default; + + typedef std::function Callback; + virtual bool compress(const char *data, size_t data_length, bool last, + Callback callback) = 0; +}; + +class decompressor { +public: + virtual ~decompressor() = default; + + virtual bool is_valid() const = 0; + + typedef std::function Callback; + virtual bool decompress(const char *data, size_t data_length, + Callback callback) = 0; +}; + +class nocompressor final : public compressor { +public: + ~nocompressor() override = default; + + bool compress(const char *data, size_t data_length, bool /*last*/, + Callback callback) override; +}; + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +class gzip_compressor final : public compressor { +public: + gzip_compressor(); + ~gzip_compressor() override; + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; + +class gzip_decompressor final : public decompressor { +public: + gzip_decompressor(); + ~gzip_decompressor() override; + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + bool is_valid_ = false; + z_stream strm_; +}; +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +class brotli_compressor final : public compressor { +public: + brotli_compressor(); + ~brotli_compressor(); + + bool compress(const char *data, size_t data_length, bool last, + Callback callback) override; + +private: + BrotliEncoderState *state_ = nullptr; +}; + +class brotli_decompressor final : public decompressor { +public: + brotli_decompressor(); + ~brotli_decompressor(); + + bool is_valid() const override; + + bool decompress(const char *data, size_t data_length, + Callback callback) override; + +private: + BrotliDecoderResult decoder_r; + BrotliDecoderState *decoder_s = nullptr; +}; +#endif + +// NOTE: until the read size reaches `fixed_buffer_size`, use `fixed_buffer` +// to store data. The call can set memory on stack for performance. +class stream_line_reader { +public: + stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size); + const char *ptr() const; + size_t size() const; + bool end_with_crlf() const; + bool getline(); + +private: + void append(char c); + + Stream &strm_; + char *fixed_buffer_; + const size_t fixed_buffer_size_; + size_t fixed_buffer_used_size_ = 0; + std::string glowable_buffer_; +}; + +class mmap { +public: + mmap(const char *path); + ~mmap(); + + bool open(const char *path); + void close(); + + bool is_open() const; + size_t size() const; + const char *data() const; + +private: +#if defined(_WIN32) + HANDLE hFile_; + HANDLE hMapping_; +#else + int fd_; +#endif + size_t size_; + void *addr_; +}; + +} // namespace detail + +// ---------------------------------------------------------------------------- + +/* + * Implementation that will be part of the .cc file if split into .h + .cc. + */ + +namespace detail { + +inline bool is_hex(char c, int &v) { + if (0x20 <= c && isdigit(c)) { + v = c - '0'; + return true; + } else if ('A' <= c && c <= 'F') { + v = c - 'A' + 10; + return true; + } else if ('a' <= c && c <= 'f') { + v = c - 'a' + 10; + return true; + } + return false; +} + +inline bool from_hex_to_i(const std::string &s, size_t i, size_t cnt, + int &val) { + if (i >= s.size()) { return false; } + + val = 0; + for (; cnt; i++, cnt--) { + if (!s[i]) { return false; } + auto v = 0; + if (is_hex(s[i], v)) { + val = val * 16 + v; + } else { + return false; + } + } + return true; +} + +inline std::string from_i_to_hex(size_t n) { + static const auto charset = "0123456789abcdef"; + std::string ret; + do { + ret = charset[n & 15] + ret; + n >>= 4; + } while (n > 0); + return ret; +} + +inline size_t to_utf8(int code, char *buff) { + if (code < 0x0080) { + buff[0] = static_cast(code & 0x7F); + return 1; + } else if (code < 0x0800) { + buff[0] = static_cast(0xC0 | ((code >> 6) & 0x1F)); + buff[1] = static_cast(0x80 | (code & 0x3F)); + return 2; + } else if (code < 0xD800) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0xE000) { // D800 - DFFF is invalid... + return 0; + } else if (code < 0x10000) { + buff[0] = static_cast(0xE0 | ((code >> 12) & 0xF)); + buff[1] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[2] = static_cast(0x80 | (code & 0x3F)); + return 3; + } else if (code < 0x110000) { + buff[0] = static_cast(0xF0 | ((code >> 18) & 0x7)); + buff[1] = static_cast(0x80 | ((code >> 12) & 0x3F)); + buff[2] = static_cast(0x80 | ((code >> 6) & 0x3F)); + buff[3] = static_cast(0x80 | (code & 0x3F)); + return 4; + } + + // NOTREACHED + return 0; +} + +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c +inline std::string base64_encode(const std::string &in) { + static const auto lookup = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + std::string out; + out.reserve(in.size()); + + auto val = 0; + auto valb = -6; + + for (auto c : in) { + val = (val << 8) + static_cast(c); + valb += 8; + while (valb >= 0) { + out.push_back(lookup[(val >> valb) & 0x3F]); + valb -= 6; + } + } + + if (valb > -6) { out.push_back(lookup[((val << 8) >> (valb + 8)) & 0x3F]); } + + while (out.size() % 4) { + out.push_back('='); + } + + return out; +} + +inline bool is_file(const std::string &path) { +#ifdef _WIN32 + return _access_s(path.c_str(), 0) == 0; +#else + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISREG(st.st_mode); +#endif +} + +inline bool is_dir(const std::string &path) { + struct stat st; + return stat(path.c_str(), &st) >= 0 && S_ISDIR(st.st_mode); +} + +inline bool is_valid_path(const std::string &path) { + size_t level = 0; + size_t i = 0; + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + + while (i < path.size()) { + // Read component + auto beg = i; + while (i < path.size() && path[i] != '/') { + if (path[i] == '\0') { + return false; + } else if (path[i] == '\\') { + return false; + } + i++; + } + + auto len = i - beg; + assert(len > 0); + + if (!path.compare(beg, len, ".")) { + ; + } else if (!path.compare(beg, len, "..")) { + if (level == 0) { return false; } + level--; + } else { + level++; + } + + // Skip slash + while (i < path.size() && path[i] == '/') { + i++; + } + } + + return true; +} + +inline std::string encode_query_param(const std::string &value) { + std::ostringstream escaped; + escaped.fill('0'); + escaped << std::hex; + + for (auto c : value) { + if (std::isalnum(static_cast(c)) || c == '-' || c == '_' || + c == '.' || c == '!' || c == '~' || c == '*' || c == '\'' || c == '(' || + c == ')') { + escaped << c; + } else { + escaped << std::uppercase; + escaped << '%' << std::setw(2) + << static_cast(static_cast(c)); + escaped << std::nouppercase; + } + } + + return escaped.str(); +} + +inline std::string encode_url(const std::string &s) { + std::string result; + result.reserve(s.size()); + + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case ' ': result += "%20"; break; + case '+': result += "%2B"; break; + case '\r': result += "%0D"; break; + case '\n': result += "%0A"; break; + case '\'': result += "%27"; break; + case ',': result += "%2C"; break; + // case ':': result += "%3A"; break; // ok? probably... + case ';': result += "%3B"; break; + default: + auto c = static_cast(s[i]); + if (c >= 0x80) { + result += '%'; + char hex[4]; + auto len = snprintf(hex, sizeof(hex) - 1, "%02X", c); + assert(len == 2); + result.append(hex, static_cast(len)); + } else { + result += s[i]; + } + break; + } + } + + return result; +} + +inline std::string decode_url(const std::string &s, + bool convert_plus_to_space) { + std::string result; + + for (size_t i = 0; i < s.size(); i++) { + if (s[i] == '%' && i + 1 < s.size()) { + if (s[i + 1] == 'u') { + auto val = 0; + if (from_hex_to_i(s, i + 2, 4, val)) { + // 4 digits Unicode codes + char buff[4]; + size_t len = to_utf8(val, buff); + if (len > 0) { result.append(buff, len); } + i += 5; // 'u0000' + } else { + result += s[i]; + } + } else { + auto val = 0; + if (from_hex_to_i(s, i + 1, 2, val)) { + // 2 digits hex codes + result += static_cast(val); + i += 2; // '00' + } else { + result += s[i]; + } + } + } else if (convert_plus_to_space && s[i] == '+') { + result += ' '; + } else { + result += s[i]; + } + } + + return result; +} + +inline void read_file(const std::string &path, std::string &out) { + std::ifstream fs(path, std::ios_base::binary); + fs.seekg(0, std::ios_base::end); + auto size = fs.tellg(); + fs.seekg(0); + out.resize(static_cast(size)); + fs.read(&out[0], static_cast(size)); +} + +inline std::string file_extension(const std::string &path) { + std::smatch m; + static auto re = std::regex("\\.([a-zA-Z0-9]+)$"); + if (std::regex_search(path, m, re)) { return m[1].str(); } + return std::string(); +} + +inline bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } + +inline std::pair trim(const char *b, const char *e, size_t left, + size_t right) { + while (b + left < e && is_space_or_tab(b[left])) { + left++; + } + while (right > 0 && is_space_or_tab(b[right - 1])) { + right--; + } + return std::make_pair(left, right); +} + +inline std::string trim_copy(const std::string &s) { + auto r = trim(s.data(), s.data() + s.size(), 0, s.size()); + return s.substr(r.first, r.second - r.first); +} + +inline std::string trim_double_quotes_copy(const std::string &s) { + if (s.length() >= 2 && s.front() == '"' && s.back() == '"') { + return s.substr(1, s.size() - 2); + } + return s; +} + +inline void split(const char *b, const char *e, char d, + std::function fn) { + return split(b, e, d, (std::numeric_limits::max)(), std::move(fn)); +} + +inline void split(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { fn(&b[r.first], &b[r.second]); } + } +} + +inline stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, + size_t fixed_buffer_size) + : strm_(strm), fixed_buffer_(fixed_buffer), + fixed_buffer_size_(fixed_buffer_size) {} + +inline const char *stream_line_reader::ptr() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_; + } else { + return glowable_buffer_.data(); + } +} + +inline size_t stream_line_reader::size() const { + if (glowable_buffer_.empty()) { + return fixed_buffer_used_size_; + } else { + return glowable_buffer_.size(); + } +} + +inline bool stream_line_reader::end_with_crlf() const { + auto end = ptr() + size(); + return size() >= 2 && end[-2] == '\r' && end[-1] == '\n'; +} + +inline bool stream_line_reader::getline() { + fixed_buffer_used_size_ = 0; + glowable_buffer_.clear(); + + for (size_t i = 0;; i++) { + char byte; + auto n = strm_.read(&byte, 1); + + if (n < 0) { + return false; + } else if (n == 0) { + if (i == 0) { + return false; + } else { + break; + } + } + + append(byte); + + if (byte == '\n') { break; } + } + + return true; +} + +inline void stream_line_reader::append(char c) { + if (fixed_buffer_used_size_ < fixed_buffer_size_ - 1) { + fixed_buffer_[fixed_buffer_used_size_++] = c; + fixed_buffer_[fixed_buffer_used_size_] = '\0'; + } else { + if (glowable_buffer_.empty()) { + assert(fixed_buffer_[fixed_buffer_used_size_] == '\0'); + glowable_buffer_.assign(fixed_buffer_, fixed_buffer_used_size_); + } + glowable_buffer_ += c; + } +} + +inline mmap::mmap(const char *path) +#if defined(_WIN32) + : hFile_(NULL), hMapping_(NULL) +#else + : fd_(-1) +#endif + , + size_(0), addr_(nullptr) { + open(path); +} + +inline mmap::~mmap() { close(); } + +inline bool mmap::open(const char *path) { + close(); + +#if defined(_WIN32) + std::wstring wpath; + for (size_t i = 0; i < strlen(path); i++) { + wpath += path[i]; + } + + hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, + OPEN_EXISTING, NULL); + + if (hFile_ == INVALID_HANDLE_VALUE) { return false; } + + LARGE_INTEGER size{}; + if (!::GetFileSizeEx(hFile_, &size)) { return false; } + size_ = static_cast(size.QuadPart); + + hMapping_ = + ::CreateFileMappingFromApp(hFile_, NULL, PAGE_READONLY, size_, NULL); + + if (hMapping_ == NULL) { + close(); + return false; + } + + addr_ = ::MapViewOfFileFromApp(hMapping_, FILE_MAP_READ, 0, 0); +#else + fd_ = ::open(path, O_RDONLY); + if (fd_ == -1) { return false; } + + struct stat sb; + if (fstat(fd_, &sb) == -1) { + close(); + return false; + } + size_ = static_cast(sb.st_size); + + addr_ = ::mmap(NULL, size_, PROT_READ, MAP_PRIVATE, fd_, 0); +#endif + + if (addr_ == nullptr) { + close(); + return false; + } + + return true; +} + +inline bool mmap::is_open() const { return addr_ != nullptr; } + +inline size_t mmap::size() const { return size_; } + +inline const char *mmap::data() const { + return static_cast(addr_); +} + +inline void mmap::close() { +#if defined(_WIN32) + if (addr_) { + ::UnmapViewOfFile(addr_); + addr_ = nullptr; + } + + if (hMapping_) { + ::CloseHandle(hMapping_); + hMapping_ = NULL; + } + + if (hFile_ != INVALID_HANDLE_VALUE) { + ::CloseHandle(hFile_); + hFile_ = INVALID_HANDLE_VALUE; + } +#else + if (addr_ != nullptr) { + munmap(addr_, size_); + addr_ = nullptr; + } + + if (fd_ != -1) { + ::close(fd_); + fd_ = -1; + } +#endif + size_ = 0; +} +inline int close_socket(socket_t sock) { +#ifdef _WIN32 + return closesocket(sock); +#else + return close(sock); +#endif +} + +template inline ssize_t handle_EINTR(T fn) { + ssize_t res = 0; + while (true) { + res = fn(); + if (res < 0 && errno == EINTR) { continue; } + break; + } + return res; +} + +inline ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags) { + return handle_EINTR([&]() { + return recv(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t send_socket(socket_t sock, const void *ptr, size_t size, + int flags) { + return handle_EINTR([&]() { + return send(sock, +#ifdef _WIN32 + static_cast(ptr), static_cast(size), +#else + ptr, size, +#endif + flags); + }); +} + +inline ssize_t select_read(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { return -1; } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { + return select(static_cast(sock + 1), &fds, nullptr, nullptr, &tv); + }); +#endif +} + +inline ssize_t select_write(socket_t sock, time_t sec, time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + return handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { return -1; } +#endif + + fd_set fds; + FD_ZERO(&fds); + FD_SET(sock, &fds); + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + return handle_EINTR([&]() { + return select(static_cast(sock + 1), nullptr, &fds, nullptr, &tv); + }); +#endif +} + +inline Error wait_until_socket_is_ready(socket_t sock, time_t sec, + time_t usec) { +#ifdef CPPHTTPLIB_USE_POLL + struct pollfd pfd_read; + pfd_read.fd = sock; + pfd_read.events = POLLIN | POLLOUT; + + auto timeout = static_cast(sec * 1000 + usec / 1000); + + auto poll_res = handle_EINTR([&]() { return poll(&pfd_read, 1, timeout); }); + + if (poll_res == 0) { return Error::ConnectionTimeout; } + + if (poll_res > 0 && pfd_read.revents & (POLLIN | POLLOUT)) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + + return Error::Connection; +#else +#ifndef _WIN32 + if (sock >= FD_SETSIZE) { return Error::Connection; } +#endif + + fd_set fdsr; + FD_ZERO(&fdsr); + FD_SET(sock, &fdsr); + + auto fdsw = fdsr; + auto fdse = fdsr; + + timeval tv; + tv.tv_sec = static_cast(sec); + tv.tv_usec = static_cast(usec); + + auto ret = handle_EINTR([&]() { + return select(static_cast(sock + 1), &fdsr, &fdsw, &fdse, &tv); + }); + + if (ret == 0) { return Error::ConnectionTimeout; } + + if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { + auto error = 0; + socklen_t len = sizeof(error); + auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, + reinterpret_cast(&error), &len); + auto successful = res >= 0 && !error; + return successful ? Error::Success : Error::Connection; + } + return Error::Connection; +#endif +} + +inline bool is_socket_alive(socket_t sock) { + const auto val = detail::select_read(sock, 0, 0); + if (val == 0) { + return true; + } else if (val < 0 && errno == EBADF) { + return false; + } + char buf[1]; + return detail::read_socket(sock, &buf[0], sizeof(buf), MSG_PEEK) > 0; +} + +class SocketStream final : public Stream { +public: + SocketStream(socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, + time_t write_timeout_sec, time_t write_timeout_usec); + ~SocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + +private: + socket_t sock_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + + std::vector read_buff_; + size_t read_buff_off_ = 0; + size_t read_buff_content_size_ = 0; + + static const size_t read_buff_size_ = 1024l * 4; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +class SSLSocketStream final : public Stream { +public: + SSLSocketStream(socket_t sock, SSL *ssl, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool is_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + +private: + socket_t sock_; + SSL *ssl_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; +}; +#endif + +inline bool keep_alive(socket_t sock, time_t keep_alive_timeout_sec) { + using namespace std::chrono; + auto start = steady_clock::now(); + while (true) { + auto val = select_read(sock, 0, 10000); + if (val < 0) { + return false; + } else if (val == 0) { + auto current = steady_clock::now(); + auto duration = duration_cast(current - start); + auto timeout = keep_alive_timeout_sec * 1000; + if (duration.count() > timeout) { return false; } + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } else { + return true; + } + } +} + +template +inline bool +process_server_socket_core(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, T callback) { + assert(keep_alive_max_count > 0); + auto ret = false; + auto count = keep_alive_max_count; + while (svr_sock != INVALID_SOCKET && count > 0 && + keep_alive(sock, keep_alive_timeout_sec)) { + auto close_connection = count == 1; + auto connection_closed = false; + ret = callback(close_connection, connection_closed); + if (!ret || connection_closed) { break; } + count--; + } + return ret; +} + +template +inline bool +process_server_socket(const std::atomic &svr_sock, socket_t sock, + size_t keep_alive_max_count, + time_t keep_alive_timeout_sec, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +inline bool process_client_socket(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec, + std::function callback) { + SocketStream strm(sock, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +inline int shutdown_socket(socket_t sock) { +#ifdef _WIN32 + return shutdown(sock, SD_BOTH); +#else + return shutdown(sock, SHUT_RDWR); +#endif +} + +template +socket_t create_socket(const std::string &host, const std::string &ip, int port, + int address_family, int socket_flags, bool tcp_nodelay, + SocketOptions socket_options, + BindOrConnect bind_or_connect) { + // Get address info + const char *node = nullptr; + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (!ip.empty()) { + node = ip.c_str(); + // Ask getaddrinfo to convert IP in c-string to address + hints.ai_family = AF_UNSPEC; + hints.ai_flags = AI_NUMERICHOST; + } else { + if (!host.empty()) { node = host.c_str(); } + hints.ai_family = address_family; + hints.ai_flags = socket_flags; + } + +#ifndef _WIN32 + if (hints.ai_family == AF_UNIX) { + const auto addrlen = host.length(); + if (addrlen > sizeof(sockaddr_un::sun_path)) { return INVALID_SOCKET; } + + auto sock = socket(hints.ai_family, hints.ai_socktype, hints.ai_protocol); + if (sock != INVALID_SOCKET) { + sockaddr_un addr{}; + addr.sun_family = AF_UNIX; + std::copy(host.begin(), host.end(), addr.sun_path); + + hints.ai_addr = reinterpret_cast(&addr); + hints.ai_addrlen = static_cast( + sizeof(addr) - sizeof(addr.sun_path) + addrlen); + + fcntl(sock, F_SETFD, FD_CLOEXEC); + if (socket_options) { socket_options(sock); } + + if (!bind_or_connect(sock, hints)) { + close_socket(sock); + sock = INVALID_SOCKET; + } + } + return sock; + } +#endif + + auto service = std::to_string(port); + + if (getaddrinfo(node, service.c_str(), &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return INVALID_SOCKET; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + // Create a socket +#ifdef _WIN32 + auto sock = + WSASocketW(rp->ai_family, rp->ai_socktype, rp->ai_protocol, nullptr, 0, + WSA_FLAG_NO_HANDLE_INHERIT | WSA_FLAG_OVERLAPPED); + /** + * Since the WSA_FLAG_NO_HANDLE_INHERIT is only supported on Windows 7 SP1 + * and above the socket creation fails on older Windows Systems. + * + * Let's try to create a socket the old way in this case. + * + * Reference: + * https://docs.microsoft.com/en-us/windows/win32/api/winsock2/nf-winsock2-wsasocketa + * + * WSA_FLAG_NO_HANDLE_INHERIT: + * This flag is supported on Windows 7 with SP1, Windows Server 2008 R2 with + * SP1, and later + * + */ + if (sock == INVALID_SOCKET) { + sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); + } +#else + auto sock = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); +#endif + if (sock == INVALID_SOCKET) { continue; } + +#ifndef _WIN32 + if (fcntl(sock, F_SETFD, FD_CLOEXEC) == -1) { + close_socket(sock); + continue; + } +#endif + + if (tcp_nodelay) { + auto yes = 1; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast(&yes), sizeof(yes)); +#else + setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast(&yes), sizeof(yes)); +#endif + } + + if (socket_options) { socket_options(sock); } + + if (rp->ai_family == AF_INET6) { + auto no = 0; +#ifdef _WIN32 + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&no), sizeof(no)); +#else + setsockopt(sock, IPPROTO_IPV6, IPV6_V6ONLY, + reinterpret_cast(&no), sizeof(no)); +#endif + } + + // bind or connect + if (bind_or_connect(sock, *rp)) { + freeaddrinfo(result); + return sock; + } + + close_socket(sock); + } + + freeaddrinfo(result); + return INVALID_SOCKET; +} + +inline void set_nonblocking(socket_t sock, bool nonblocking) { +#ifdef _WIN32 + auto flags = nonblocking ? 1UL : 0UL; + ioctlsocket(sock, FIONBIO, &flags); +#else + auto flags = fcntl(sock, F_GETFL, 0); + fcntl(sock, F_SETFL, + nonblocking ? (flags | O_NONBLOCK) : (flags & (~O_NONBLOCK))); +#endif +} + +inline bool is_connection_error() { +#ifdef _WIN32 + return WSAGetLastError() != WSAEWOULDBLOCK; +#else + return errno != EINPROGRESS; +#endif +} + +inline bool bind_ip_address(socket_t sock, const std::string &host) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(host.c_str(), "0", &hints, &result)) { return false; } + + auto ret = false; + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &ai = *rp; + if (!::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + ret = true; + break; + } + } + + freeaddrinfo(result); + return ret; +} + +#if !defined _WIN32 && !defined ANDROID && !defined _AIX && !defined __MVS__ +#define USE_IF2IP +#endif + +#ifdef USE_IF2IP +inline std::string if2ip(int address_family, const std::string &ifn) { + struct ifaddrs *ifap; + getifaddrs(&ifap); + std::string addr_candidate; + for (auto ifa = ifap; ifa; ifa = ifa->ifa_next) { + if (ifa->ifa_addr && ifn == ifa->ifa_name && + (AF_UNSPEC == address_family || + ifa->ifa_addr->sa_family == address_family)) { + if (ifa->ifa_addr->sa_family == AF_INET) { + auto sa = reinterpret_cast(ifa->ifa_addr); + char buf[INET_ADDRSTRLEN]; + if (inet_ntop(AF_INET, &sa->sin_addr, buf, INET_ADDRSTRLEN)) { + freeifaddrs(ifap); + return std::string(buf, INET_ADDRSTRLEN); + } + } else if (ifa->ifa_addr->sa_family == AF_INET6) { + auto sa = reinterpret_cast(ifa->ifa_addr); + if (!IN6_IS_ADDR_LINKLOCAL(&sa->sin6_addr)) { + char buf[INET6_ADDRSTRLEN] = {}; + if (inet_ntop(AF_INET6, &sa->sin6_addr, buf, INET6_ADDRSTRLEN)) { + // equivalent to mac's IN6_IS_ADDR_UNIQUE_LOCAL + auto s6_addr_head = sa->sin6_addr.s6_addr[0]; + if (s6_addr_head == 0xfc || s6_addr_head == 0xfd) { + addr_candidate = std::string(buf, INET6_ADDRSTRLEN); + } else { + freeifaddrs(ifap); + return std::string(buf, INET6_ADDRSTRLEN); + } + } + } + } + } + } + freeifaddrs(ifap); + return addr_candidate; +} +#endif + +inline socket_t create_client_socket( + const std::string &host, const std::string &ip, int port, + int address_family, bool tcp_nodelay, SocketOptions socket_options, + time_t connection_timeout_sec, time_t connection_timeout_usec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, const std::string &intf, Error &error) { + auto sock = create_socket( + host, ip, port, address_family, 0, tcp_nodelay, std::move(socket_options), + [&](socket_t sock2, struct addrinfo &ai) -> bool { + if (!intf.empty()) { +#ifdef USE_IF2IP + auto ip_from_if = if2ip(address_family, intf); + if (ip_from_if.empty()) { ip_from_if = intf; } + if (!bind_ip_address(sock2, ip_from_if)) { + error = Error::BindIPAddress; + return false; + } +#endif + } + + set_nonblocking(sock2, true); + + auto ret = + ::connect(sock2, ai.ai_addr, static_cast(ai.ai_addrlen)); + + if (ret < 0) { + if (is_connection_error()) { + error = Error::Connection; + return false; + } + error = wait_until_socket_is_ready(sock2, connection_timeout_sec, + connection_timeout_usec); + if (error != Error::Success) { return false; } + } + + set_nonblocking(sock2, false); + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec * 1000 + + read_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec); + tv.tv_usec = static_cast(read_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec * 1000 + + write_timeout_usec / 1000); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec); + tv.tv_usec = static_cast(write_timeout_usec); + setsockopt(sock2, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + error = Error::Success; + return true; + }); + + if (sock != INVALID_SOCKET) { + error = Error::Success; + } else { + if (error == Error::Success) { error = Error::Connection; } + } + + return sock; +} + +inline bool get_ip_and_port(const struct sockaddr_storage &addr, + socklen_t addr_len, std::string &ip, int &port) { + if (addr.ss_family == AF_INET) { + port = ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + port = + ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return false; + } + + std::array ipstr{}; + if (getnameinfo(reinterpret_cast(&addr), addr_len, + ipstr.data(), static_cast(ipstr.size()), nullptr, + 0, NI_NUMERICHOST)) { + return false; + } + + ip = ipstr.data(); + return true; +} + +inline void get_local_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (!getsockname(sock, reinterpret_cast(&addr), + &addr_len)) { + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + + if (!getpeername(sock, reinterpret_cast(&addr), + &addr_len)) { +#ifndef _WIN32 + if (addr.ss_family == AF_UNIX) { +#if defined(__linux__) + struct ucred ucred; + socklen_t len = sizeof(ucred); + if (getsockopt(sock, SOL_SOCKET, SO_PEERCRED, &ucred, &len) == 0) { + port = ucred.pid; + } +#elif defined(SOL_LOCAL) && defined(SO_PEERPID) // __APPLE__ + pid_t pid; + socklen_t len = sizeof(pid); + if (getsockopt(sock, SOL_LOCAL, SO_PEERPID, &pid, &len) == 0) { + port = pid; + } +#endif + return; + } +#endif + get_ip_and_port(addr, addr_len, ip, port); + } +} + +inline constexpr unsigned int str2tag_core(const char *s, size_t l, + unsigned int h) { + return (l == 0) + ? h + : str2tag_core( + s + 1, l - 1, + // Unsets the 6 high bits of h, therefore no overflow happens + (((std::numeric_limits::max)() >> 6) & + h * 33) ^ + static_cast(*s)); +} + +inline unsigned int str2tag(const std::string &s) { + return str2tag_core(s.data(), s.size(), 0); +} + +namespace udl { + +inline constexpr unsigned int operator"" _t(const char *s, size_t l) { + return str2tag_core(s, l, 0); +} + +} // namespace udl + +inline std::string +find_content_type(const std::string &path, + const std::map &user_data, + const std::string &default_content_type) { + auto ext = file_extension(path); + + auto it = user_data.find(ext); + if (it != user_data.end()) { return it->second; } + + using udl::operator""_t; + + switch (str2tag(ext)) { + default: return default_content_type; + + case "css"_t: return "text/css"; + case "csv"_t: return "text/csv"; + case "htm"_t: + case "html"_t: return "text/html"; + case "js"_t: + case "mjs"_t: return "text/javascript"; + case "txt"_t: return "text/plain"; + case "vtt"_t: return "text/vtt"; + + case "apng"_t: return "image/apng"; + case "avif"_t: return "image/avif"; + case "bmp"_t: return "image/bmp"; + case "gif"_t: return "image/gif"; + case "png"_t: return "image/png"; + case "svg"_t: return "image/svg+xml"; + case "webp"_t: return "image/webp"; + case "ico"_t: return "image/x-icon"; + case "tif"_t: return "image/tiff"; + case "tiff"_t: return "image/tiff"; + case "jpg"_t: + case "jpeg"_t: return "image/jpeg"; + + case "mp4"_t: return "video/mp4"; + case "mpeg"_t: return "video/mpeg"; + case "webm"_t: return "video/webm"; + + case "mp3"_t: return "audio/mp3"; + case "mpga"_t: return "audio/mpeg"; + case "weba"_t: return "audio/webm"; + case "wav"_t: return "audio/wave"; + + case "otf"_t: return "font/otf"; + case "ttf"_t: return "font/ttf"; + case "woff"_t: return "font/woff"; + case "woff2"_t: return "font/woff2"; + + case "7z"_t: return "application/x-7z-compressed"; + case "atom"_t: return "application/atom+xml"; + case "pdf"_t: return "application/pdf"; + case "json"_t: return "application/json"; + case "rss"_t: return "application/rss+xml"; + case "tar"_t: return "application/x-tar"; + case "xht"_t: + case "xhtml"_t: return "application/xhtml+xml"; + case "xslt"_t: return "application/xslt+xml"; + case "xml"_t: return "application/xml"; + case "gz"_t: return "application/gzip"; + case "zip"_t: return "application/zip"; + case "wasm"_t: return "application/wasm"; + } +} + +inline bool can_compress_content_type(const std::string &content_type) { + using udl::operator""_t; + + auto tag = str2tag(content_type); + + switch (tag) { + case "image/svg+xml"_t: + case "application/javascript"_t: + case "application/json"_t: + case "application/xml"_t: + case "application/protobuf"_t: + case "application/xhtml+xml"_t: return true; + + default: + return !content_type.rfind("text/", 0) && tag != "text/event-stream"_t; + } +} + +inline EncodingType encoding_type(const Request &req, const Response &res) { + auto ret = + detail::can_compress_content_type(res.get_header_value("Content-Type")); + if (!ret) { return EncodingType::None; } + + const auto &s = req.get_header_value("Accept-Encoding"); + (void)(s); + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + // TODO: 'Accept-Encoding' has br, not br;q=0 + ret = s.find("br") != std::string::npos; + if (ret) { return EncodingType::Brotli; } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 + ret = s.find("gzip") != std::string::npos; + if (ret) { return EncodingType::Gzip; } +#endif + + return EncodingType::None; +} + +inline bool nocompressor::compress(const char *data, size_t data_length, + bool /*last*/, Callback callback) { + if (!data_length) { return true; } + return callback(data, data_length); +} + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT +inline gzip_compressor::gzip_compressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + is_valid_ = deflateInit2(&strm_, Z_DEFAULT_COMPRESSION, Z_DEFLATED, 31, 8, + Z_DEFAULT_STRATEGY) == Z_OK; +} + +inline gzip_compressor::~gzip_compressor() { deflateEnd(&strm_); } + +inline bool gzip_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + assert(is_valid_); + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + auto flush = (last && data_length == 0) ? Z_FINISH : Z_NO_FLUSH; + auto ret = Z_OK; + + std::array buff{}; + do { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = deflate(&strm_, flush); + if (ret == Z_STREAM_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } while (strm_.avail_out == 0); + + assert((flush == Z_FINISH && ret == Z_STREAM_END) || + (flush == Z_NO_FLUSH && ret == Z_OK)); + assert(strm_.avail_in == 0); + } while (data_length > 0); + + return true; +} + +inline gzip_decompressor::gzip_decompressor() { + std::memset(&strm_, 0, sizeof(strm_)); + strm_.zalloc = Z_NULL; + strm_.zfree = Z_NULL; + strm_.opaque = Z_NULL; + + // 15 is the value of wbits, which should be at the maximum possible value + // to ensure that any gzip stream can be decoded. The offset of 32 specifies + // that the stream type should be automatically detected either gzip or + // deflate. + is_valid_ = inflateInit2(&strm_, 32 + 15) == Z_OK; +} + +inline gzip_decompressor::~gzip_decompressor() { inflateEnd(&strm_); } + +inline bool gzip_decompressor::is_valid() const { return is_valid_; } + +inline bool gzip_decompressor::decompress(const char *data, size_t data_length, + Callback callback) { + assert(is_valid_); + + auto ret = Z_OK; + + do { + constexpr size_t max_avail_in = + (std::numeric_limits::max)(); + + strm_.avail_in = static_cast( + (std::min)(data_length, max_avail_in)); + strm_.next_in = const_cast(reinterpret_cast(data)); + + data_length -= strm_.avail_in; + data += strm_.avail_in; + + std::array buff{}; + while (strm_.avail_in > 0 && ret == Z_OK) { + strm_.avail_out = static_cast(buff.size()); + strm_.next_out = reinterpret_cast(buff.data()); + + ret = inflate(&strm_, Z_NO_FLUSH); + + assert(ret != Z_STREAM_ERROR); + switch (ret) { + case Z_NEED_DICT: + case Z_DATA_ERROR: + case Z_MEM_ERROR: inflateEnd(&strm_); return false; + } + + if (!callback(buff.data(), buff.size() - strm_.avail_out)) { + return false; + } + } + + if (ret != Z_OK && ret != Z_STREAM_END) { return false; } + + } while (data_length > 0); + + return true; +} +#endif + +#ifdef CPPHTTPLIB_BROTLI_SUPPORT +inline brotli_compressor::brotli_compressor() { + state_ = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); +} + +inline brotli_compressor::~brotli_compressor() { + BrotliEncoderDestroyInstance(state_); +} + +inline bool brotli_compressor::compress(const char *data, size_t data_length, + bool last, Callback callback) { + std::array buff{}; + + auto operation = last ? BROTLI_OPERATION_FINISH : BROTLI_OPERATION_PROCESS; + auto available_in = data_length; + auto next_in = reinterpret_cast(data); + + for (;;) { + if (last) { + if (BrotliEncoderIsFinished(state_)) { break; } + } else { + if (!available_in) { break; } + } + + auto available_out = buff.size(); + auto next_out = buff.data(); + + if (!BrotliEncoderCompressStream(state_, operation, &available_in, &next_in, + &available_out, &next_out, nullptr)) { + return false; + } + + auto output_bytes = buff.size() - available_out; + if (output_bytes) { + callback(reinterpret_cast(buff.data()), output_bytes); + } + } + + return true; +} + +inline brotli_decompressor::brotli_decompressor() { + decoder_s = BrotliDecoderCreateInstance(0, 0, 0); + decoder_r = decoder_s ? BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT + : BROTLI_DECODER_RESULT_ERROR; +} + +inline brotli_decompressor::~brotli_decompressor() { + if (decoder_s) { BrotliDecoderDestroyInstance(decoder_s); } +} + +inline bool brotli_decompressor::is_valid() const { return decoder_s; } + +inline bool brotli_decompressor::decompress(const char *data, + size_t data_length, + Callback callback) { + if (decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_ERROR) { + return 0; + } + + auto next_in = reinterpret_cast(data); + size_t avail_in = data_length; + size_t total_out; + + decoder_r = BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT; + + std::array buff{}; + while (decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT) { + char *next_out = buff.data(); + size_t avail_out = buff.size(); + + decoder_r = BrotliDecoderDecompressStream( + decoder_s, &avail_in, &next_in, &avail_out, + reinterpret_cast(&next_out), &total_out); + + if (decoder_r == BROTLI_DECODER_RESULT_ERROR) { return false; } + + if (!callback(buff.data(), buff.size() - avail_out)) { return false; } + } + + return decoder_r == BROTLI_DECODER_RESULT_SUCCESS || + decoder_r == BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT; +} +#endif + +inline bool has_header(const Headers &headers, const std::string &key) { + return headers.find(key) != headers.end(); +} + +inline const char *get_header_value(const Headers &headers, + const std::string &key, size_t id, + const char *def) { + auto rng = headers.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second.c_str(); } + return def; +} + +inline bool compare_case_ignore(const std::string &a, const std::string &b) { + if (a.size() != b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (::tolower(a[i]) != ::tolower(b[i])) { return false; } + } + return true; +} + +template +inline bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p < end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = compare_case_ignore(key, "Location") + ? std::string(p, end) + : decode_url(std::string(p, end), false); + fn(std::move(key), std::move(val)); + return true; + } + + return false; +} + +inline bool read_headers(Stream &strm, Headers &headers) { + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + + for (;;) { + if (!line_reader.getline()) { return false; } + + // Check if the line ends with CRLF. + auto line_terminator_len = 2; + if (line_reader.end_with_crlf()) { + // Blank line indicates end of headers. + if (line_reader.size() == 2) { break; } +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + } else { + // Blank line indicates end of headers. + if (line_reader.size() == 1) { break; } + line_terminator_len = 1; + } +#else + } else { + continue; // Skip invalid line. + } +#endif + + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Exclude line terminator + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](std::string &&key, std::string &&val) { + headers.emplace(std::move(key), std::move(val)); + }); + } + + return true; +} + +inline bool read_content_with_length(Stream &strm, uint64_t len, + Progress progress, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return false; } + + if (!out(buf, static_cast(n), r, len)) { return false; } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return false; } + } + } + + return true; +} + +inline void skip_content_with_length(Stream &strm, uint64_t len) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + if (n <= 0) { return; } + r += static_cast(n); + } +} + +inline bool read_content_without_length(Stream &strm, + ContentReceiverWithProgress out) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + uint64_t r = 0; + for (;;) { + auto n = strm.read(buf, CPPHTTPLIB_RECV_BUFSIZ); + if (n <= 0) { return true; } + + if (!out(buf, static_cast(n), r, 0)) { return false; } + r += static_cast(n); + } + + return true; +} + +template +inline bool read_content_chunked(Stream &strm, T &x, + ContentReceiverWithProgress out) { + const auto bufsiz = 16; + char buf[bufsiz]; + + stream_line_reader line_reader(strm, buf, bufsiz); + + if (!line_reader.getline()) { return false; } + + unsigned long chunk_len; + while (true) { + char *end_ptr; + + chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + + if (end_ptr == line_reader.ptr()) { return false; } + if (chunk_len == ULONG_MAX) { return false; } + + if (chunk_len == 0) { break; } + + if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + return false; + } + + if (!line_reader.getline()) { return false; } + + if (strcmp(line_reader.ptr(), "\r\n") != 0) { return false; } + + if (!line_reader.getline()) { return false; } + } + + assert(chunk_len == 0); + + // Trailer + if (!line_reader.getline()) { return false; } + + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + + // Exclude line terminator + constexpr auto line_terminator_len = 2; + auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; + + parse_header(line_reader.ptr(), end, + [&](std::string &&key, std::string &&val) { + x.headers.emplace(std::move(key), std::move(val)); + }); + + if (!line_reader.getline()) { return false; } + } + + return true; +} + +inline bool is_chunked_transfer_encoding(const Headers &headers) { + return compare_case_ignore( + get_header_value(headers, "Transfer-Encoding", 0, ""), "chunked"); +} + +template +bool prepare_content_receiver(T &x, int &status, + ContentReceiverWithProgress receiver, + bool decompress, U callback) { + if (decompress) { + std::string encoding = x.get_header_value("Content-Encoding"); + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#else + status = StatusCode::UnsupportedMediaType_415; + return false; +#endif + } + + if (decompressor) { + if (decompressor->is_valid()) { + ContentReceiverWithProgress out = [&](const char *buf, size_t n, + uint64_t off, uint64_t len) { + return decompressor->decompress(buf, n, + [&](const char *buf2, size_t n2) { + return receiver(buf2, n2, off, len); + }); + }; + return callback(std::move(out)); + } else { + status = StatusCode::InternalServerError_500; + return false; + } + } + } + + ContentReceiverWithProgress out = [&](const char *buf, size_t n, uint64_t off, + uint64_t len) { + return receiver(buf, n, off, len); + }; + return callback(std::move(out)); +} + +template +bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, + Progress progress, ContentReceiverWithProgress receiver, + bool decompress) { + return prepare_content_receiver( + x, status, std::move(receiver), decompress, + [&](const ContentReceiverWithProgress &out) { + auto ret = true; + auto exceed_payload_max_length = false; + + if (is_chunked_transfer_encoding(x.headers)) { + ret = read_content_chunked(strm, x, out); + } else if (!has_header(x.headers, "Content-Length")) { + ret = read_content_without_length(strm, out); + } else { + auto len = get_header_value_u64(x.headers, "Content-Length", 0, 0); + if (len > payload_max_length) { + exceed_payload_max_length = true; + skip_content_with_length(strm, len); + ret = false; + } else if (len > 0) { + ret = read_content_with_length(strm, len, std::move(progress), out); + } + } + + if (!ret) { + status = exceed_payload_max_length ? StatusCode::PayloadTooLarge_413 + : StatusCode::BadRequest_400; + } + return ret; + }); +} // namespace detail + +inline ssize_t write_headers(Stream &strm, const Headers &headers) { + ssize_t write_len = 0; + for (const auto &x : headers) { + auto len = + strm.write_format("%s: %s\r\n", x.first.c_str(), x.second.c_str()); + if (len < 0) { return len; } + write_len += len; + } + auto len = strm.write("\r\n"); + if (len < 0) { return len; } + write_len += len; + return write_len; +} + +inline bool write_data(Stream &strm, const char *d, size_t l) { + size_t offset = 0; + while (offset < l) { + auto length = strm.write(d + offset, l - offset); + if (length < 0) { return false; } + offset += static_cast(length); + } + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, T is_shutting_down, + Error &error) { + size_t end_offset = offset + length; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + if (strm.is_writable() && write_data(strm, d, l)) { + offset += l; + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + while (offset < end_offset && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, end_offset - offset, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content(Stream &strm, const ContentProvider &content_provider, + size_t offset, size_t length, + const T &is_shutting_down) { + auto error = Error::Success; + return write_content(strm, content_provider, offset, length, is_shutting_down, + error); +} + +template +inline bool +write_content_without_length(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + offset += l; + if (!strm.is_writable() || !write_data(strm, d, l)) { ok = false; } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + data_sink.done = [&](void) { data_available = false; }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + return false; + } else if (!content_provider(offset, 0, data_sink)) { + return false; + } else if (!ok) { + return false; + } + } + return true; +} + +template +inline bool +write_content_chunked(Stream &strm, const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor, Error &error) { + size_t offset = 0; + auto data_available = true; + auto ok = true; + DataSink data_sink; + + data_sink.write = [&](const char *d, size_t l) -> bool { + if (ok) { + data_available = l > 0; + offset += l; + + std::string payload; + if (compressor.compress(d, l, false, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = + from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || + !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + } + } + } else { + ok = false; + } + } + return ok; + }; + + data_sink.is_writable = [&]() -> bool { return strm.is_writable(); }; + + auto done_with_trailer = [&](const Headers *trailer) { + if (!ok) { return; } + + data_available = false; + + std::string payload; + if (!compressor.compress(nullptr, 0, true, + [&](const char *data, size_t data_len) { + payload.append(data, data_len); + return true; + })) { + ok = false; + return; + } + + if (!payload.empty()) { + // Emit chunked response header and footer for each chunk + auto chunk = from_i_to_hex(payload.size()) + "\r\n" + payload + "\r\n"; + if (!strm.is_writable() || + !write_data(strm, chunk.data(), chunk.size())) { + ok = false; + return; + } + } + + static const std::string done_marker("0\r\n"); + if (!write_data(strm, done_marker.data(), done_marker.size())) { + ok = false; + } + + // Trailer + if (trailer) { + for (const auto &kv : *trailer) { + std::string field_line = kv.first + ": " + kv.second + "\r\n"; + if (!write_data(strm, field_line.data(), field_line.size())) { + ok = false; + } + } + } + + static const std::string crlf("\r\n"); + if (!write_data(strm, crlf.data(), crlf.size())) { ok = false; } + }; + + data_sink.done = [&](void) { done_with_trailer(nullptr); }; + + data_sink.done_with_trailer = [&](const Headers &trailer) { + done_with_trailer(&trailer); + }; + + while (data_available && !is_shutting_down()) { + if (!strm.is_writable()) { + error = Error::Write; + return false; + } else if (!content_provider(offset, 0, data_sink)) { + error = Error::Canceled; + return false; + } else if (!ok) { + error = Error::Write; + return false; + } + } + + error = Error::Success; + return true; +} + +template +inline bool write_content_chunked(Stream &strm, + const ContentProvider &content_provider, + const T &is_shutting_down, U &compressor) { + auto error = Error::Success; + return write_content_chunked(strm, content_provider, is_shutting_down, + compressor, error); +} + +template +inline bool redirect(T &cli, Request &req, Response &res, + const std::string &path, const std::string &location, + Error &error) { + Request new_req = req; + new_req.path = path; + new_req.redirect_count_ -= 1; + + if (res.status == StatusCode::SeeOther_303 && + (req.method != "GET" && req.method != "HEAD")) { + new_req.method = "GET"; + new_req.body.clear(); + new_req.headers.clear(); + } + + Response new_res; + + auto ret = cli.send(new_req, new_res, error); + if (ret) { + req = new_req; + res = new_res; + + if (res.location.empty()) { res.location = location; } + } + return ret; +} + +inline std::string params_to_query_str(const Params ¶ms) { + std::string query; + + for (auto it = params.begin(); it != params.end(); ++it) { + if (it != params.begin()) { query += "&"; } + query += it->first; + query += "="; + query += encode_query_param(it->second); + } + return query; +} + +inline void parse_query_text(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), '&', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(decode_url(key, true), decode_url(val, true)); + } + }); +} + +inline bool parse_multipart_boundary(const std::string &content_type, + std::string &boundary) { + auto boundary_keyword = "boundary="; + auto pos = content_type.find(boundary_keyword); + if (pos == std::string::npos) { return false; } + auto end = content_type.find(';', pos); + auto beg = pos + strlen(boundary_keyword); + boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + return !boundary.empty(); +} + +inline void parse_disposition_params(const std::string &s, Params ¶ms) { + std::set cache; + split(s.data(), s.data() + s.size(), ';', [&](const char *b, const char *e) { + std::string kv(b, e); + if (cache.find(kv) != cache.end()) { return; } + cache.insert(kv); + + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + + if (!key.empty()) { + params.emplace(trim_double_quotes_copy((key)), + trim_double_quotes_copy((val))); + } + }); +} + +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +inline bool parse_range_header(const std::string &s, Ranges &ranges) { +#else +inline bool parse_range_header(const std::string &s, Ranges &ranges) try { +#endif + static auto re_first_range = std::regex(R"(bytes=(\d*-\d*(?:,\s*\d*-\d*)*))"); + std::smatch m; + if (std::regex_match(s, m, re_first_range)) { + auto pos = static_cast(m.position(1)); + auto len = static_cast(m.length(1)); + auto all_valid_ranges = true; + split(&s[pos], &s[pos + len], ',', [&](const char *b, const char *e) { + if (!all_valid_ranges) { return; } + static auto re_another_range = std::regex(R"(\s*(\d*)-(\d*))"); + std::cmatch cm; + if (std::regex_match(b, e, cm, re_another_range)) { + ssize_t first = -1; + if (!cm.str(1).empty()) { + first = static_cast(std::stoll(cm.str(1))); + } + + ssize_t last = -1; + if (!cm.str(2).empty()) { + last = static_cast(std::stoll(cm.str(2))); + } + + if (first != -1 && last != -1 && first > last) { + all_valid_ranges = false; + return; + } + ranges.emplace_back(std::make_pair(first, last)); + } + }); + return all_valid_ranges; + } + return false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS +} +#else +} catch (...) { return false; } +#endif + +class MultipartFormDataParser { +public: + MultipartFormDataParser() = default; + + void set_boundary(std::string &&boundary) { + boundary_ = boundary; + dash_boundary_crlf_ = dash_ + boundary_ + crlf_; + crlf_dash_boundary_ = crlf_ + dash_ + boundary_; + } + + bool is_valid() const { return is_valid_; } + + bool parse(const char *buf, size_t n, const ContentReceiver &content_callback, + const MultipartContentHeader &header_callback) { + + buf_append(buf, n); + + while (buf_size() > 0) { + switch (state_) { + case 0: { // Initial boundary + buf_erase(buf_find(dash_boundary_crlf_)); + if (dash_boundary_crlf_.size() > buf_size()) { return true; } + if (!buf_start_with(dash_boundary_crlf_)) { return false; } + buf_erase(dash_boundary_crlf_.size()); + state_ = 1; + break; + } + case 1: { // New entry + clear_file_info(); + state_ = 2; + break; + } + case 2: { // Headers + auto pos = buf_find(crlf_); + if (pos > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + while (pos < buf_size()) { + // Empty line + if (pos == 0) { + if (!header_callback(file_)) { + is_valid_ = false; + return false; + } + buf_erase(crlf_.size()); + state_ = 3; + break; + } + + const auto header = buf_head(pos); + + if (!parse_header(header.data(), header.data() + header.size(), + [&](std::string &&, std::string &&) {})) { + is_valid_ = false; + return false; + } + + static const std::string header_content_type = "Content-Type:"; + + if (start_with_case_ignore(header, header_content_type)) { + file_.content_type = + trim_copy(header.substr(header_content_type.size())); + } else { + static const std::regex re_content_disposition( + R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", + std::regex_constants::icase); + + std::smatch m; + if (std::regex_match(header, m, re_content_disposition)) { + Params params; + parse_disposition_params(m[1], params); + + auto it = params.find("name"); + if (it != params.end()) { + file_.name = it->second; + } else { + is_valid_ = false; + return false; + } + + it = params.find("filename"); + if (it != params.end()) { file_.filename = it->second; } + + it = params.find("filename*"); + if (it != params.end()) { + // Only allow UTF-8 enconnding... + static const std::regex re_rfc5987_encoding( + R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); + + std::smatch m2; + if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { + file_.filename = decode_url(m2[1], false); // override... + } else { + is_valid_ = false; + return false; + } + } + } + } + buf_erase(pos + crlf_.size()); + pos = buf_find(crlf_); + } + if (state_ != 3) { return true; } + break; + } + case 3: { // Body + if (crlf_dash_boundary_.size() > buf_size()) { return true; } + auto pos = buf_find(crlf_dash_boundary_); + if (pos < buf_size()) { + if (!content_callback(buf_data(), pos)) { + is_valid_ = false; + return false; + } + buf_erase(pos + crlf_dash_boundary_.size()); + state_ = 4; + } else { + auto len = buf_size() - crlf_dash_boundary_.size(); + if (len > 0) { + if (!content_callback(buf_data(), len)) { + is_valid_ = false; + return false; + } + buf_erase(len); + } + return true; + } + break; + } + case 4: { // Boundary + if (crlf_.size() > buf_size()) { return true; } + if (buf_start_with(crlf_)) { + buf_erase(crlf_.size()); + state_ = 1; + } else { + if (dash_.size() > buf_size()) { return true; } + if (buf_start_with(dash_)) { + buf_erase(dash_.size()); + is_valid_ = true; + buf_erase(buf_size()); // Remove epilogue + } else { + return true; + } + } + break; + } + } + } + + return true; + } + +private: + void clear_file_info() { + file_.name.clear(); + file_.filename.clear(); + file_.content_type.clear(); + } + + bool start_with_case_ignore(const std::string &a, + const std::string &b) const { + if (a.size() < b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (::tolower(a[i]) != ::tolower(b[i])) { return false; } + } + return true; + } + + const std::string dash_ = "--"; + const std::string crlf_ = "\r\n"; + std::string boundary_; + std::string dash_boundary_crlf_; + std::string crlf_dash_boundary_; + + size_t state_ = 0; + bool is_valid_ = false; + MultipartFormData file_; + + // Buffer + bool start_with(const std::string &a, size_t spos, size_t epos, + const std::string &b) const { + if (epos - spos < b.size()) { return false; } + for (size_t i = 0; i < b.size(); i++) { + if (a[i + spos] != b[i]) { return false; } + } + return true; + } + + size_t buf_size() const { return buf_epos_ - buf_spos_; } + + const char *buf_data() const { return &buf_[buf_spos_]; } + + std::string buf_head(size_t l) const { return buf_.substr(buf_spos_, l); } + + bool buf_start_with(const std::string &s) const { + return start_with(buf_, buf_spos_, buf_epos_, s); + } + + size_t buf_find(const std::string &s) const { + auto c = s.front(); + + size_t off = buf_spos_; + while (off < buf_epos_) { + auto pos = off; + while (true) { + if (pos == buf_epos_) { return buf_size(); } + if (buf_[pos] == c) { break; } + pos++; + } + + auto remaining_size = buf_epos_ - pos; + if (s.size() > remaining_size) { return buf_size(); } + + if (start_with(buf_, pos, buf_epos_, s)) { return pos - buf_spos_; } + + off = pos + 1; + } + + return buf_size(); + } + + void buf_append(const char *data, size_t n) { + auto remaining_size = buf_size(); + if (remaining_size > 0 && buf_spos_ > 0) { + for (size_t i = 0; i < remaining_size; i++) { + buf_[i] = buf_[buf_spos_ + i]; + } + } + buf_spos_ = 0; + buf_epos_ = remaining_size; + + if (remaining_size + n > buf_.size()) { buf_.resize(remaining_size + n); } + + for (size_t i = 0; i < n; i++) { + buf_[buf_epos_ + i] = data[i]; + } + buf_epos_ += n; + } + + void buf_erase(size_t size) { buf_spos_ += size; } + + std::string buf_; + size_t buf_spos_ = 0; + size_t buf_epos_ = 0; +}; + +inline std::string to_lower(const char *beg, const char *end) { + std::string out; + auto it = beg; + while (it != end) { + out += static_cast(::tolower(*it)); + it++; + } + return out; +} + +inline std::string random_string(size_t length) { + static const char data[] = + "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"; + + // std::random_device might actually be deterministic on some + // platforms, but due to lack of support in the c++ standard library, + // doing better requires either some ugly hacks or breaking portability. + static std::random_device seed_gen; + + // Request 128 bits of entropy for initialization + static std::seed_seq seed_sequence{seed_gen(), seed_gen(), seed_gen(), + seed_gen()}; + + static std::mt19937 engine(seed_sequence); + + std::string result; + for (size_t i = 0; i < length; i++) { + result += data[engine() % (sizeof(data) - 1)]; + } + return result; +} + +inline std::string make_multipart_data_boundary() { + return "--cpp-httplib-multipart-data-" + detail::random_string(16); +} + +inline bool is_multipart_boundary_chars_valid(const std::string &boundary) { + auto valid = true; + for (size_t i = 0; i < boundary.size(); i++) { + auto c = boundary[i]; + if (!std::isalnum(c) && c != '-' && c != '_') { + valid = false; + break; + } + } + return valid; +} + +template +inline std::string +serialize_multipart_formdata_item_begin(const T &item, + const std::string &boundary) { + std::string body = "--" + boundary + "\r\n"; + body += "Content-Disposition: form-data; name=\"" + item.name + "\""; + if (!item.filename.empty()) { + body += "; filename=\"" + item.filename + "\""; + } + body += "\r\n"; + if (!item.content_type.empty()) { + body += "Content-Type: " + item.content_type + "\r\n"; + } + body += "\r\n"; + + return body; +} + +inline std::string serialize_multipart_formdata_item_end() { return "\r\n"; } + +inline std::string +serialize_multipart_formdata_finish(const std::string &boundary) { + return "--" + boundary + "--\r\n"; +} + +inline std::string +serialize_multipart_formdata_get_content_type(const std::string &boundary) { + return "multipart/form-data; boundary=" + boundary; +} + +inline std::string +serialize_multipart_formdata(const MultipartFormDataItems &items, + const std::string &boundary, bool finish = true) { + std::string body; + + for (const auto &item : items) { + body += serialize_multipart_formdata_item_begin(item, boundary); + body += item.content + serialize_multipart_formdata_item_end(); + } + + if (finish) { body += serialize_multipart_formdata_finish(boundary); } + + return body; +} + +inline bool range_error(Request &req, Response &res) { + if (!req.ranges.empty() && 200 <= res.status && res.status < 300) { + ssize_t contant_len = static_cast( + res.content_length_ ? res.content_length_ : res.body.size()); + + ssize_t prev_first_pos = -1; + ssize_t prev_last_pos = -1; + size_t overwrapping_count = 0; + + // NOTE: The following Range check is based on '14.2. Range' in RFC 9110 + // 'HTTP Semantics' to avoid potential denial-of-service attacks. + // https://www.rfc-editor.org/rfc/rfc9110#section-14.2 + + // Too many ranges + if (req.ranges.size() > CPPHTTPLIB_RANGE_MAX_COUNT) { return true; } + + for (auto &r : req.ranges) { + auto &first_pos = r.first; + auto &last_pos = r.second; + + if (first_pos == -1 && last_pos == -1) { + first_pos = 0; + last_pos = contant_len; + } + + if (first_pos == -1) { + first_pos = contant_len - last_pos; + last_pos = contant_len - 1; + } + + if (last_pos == -1) { last_pos = contant_len - 1; } + + // Range must be within content length + if (!(0 <= first_pos && first_pos <= last_pos && + last_pos <= contant_len - 1)) { + return true; + } + + // Ranges must be in ascending order + if (first_pos <= prev_first_pos) { return true; } + + // Request must not have more than two overlapping ranges + if (first_pos <= prev_last_pos) { + overwrapping_count++; + if (overwrapping_count > 2) { return true; } + } + + prev_first_pos = (std::max)(prev_first_pos, first_pos); + prev_last_pos = (std::max)(prev_last_pos, last_pos); + } + } + + return false; +} + +inline std::pair +get_range_offset_and_length(Range r, size_t content_length) { + (void)(content_length); // patch to get rid of "unused parameter" on release build + assert(r.first != -1 && r.second != -1); + assert(0 <= r.first && r.first < static_cast(content_length)); + assert(r.first <= r.second && + r.second < static_cast(content_length)); + + return std::make_pair(r.first, static_cast(r.second - r.first) + 1); +} + +inline std::string make_content_range_header_field( + const std::pair &offset_and_length, size_t content_length) { + auto st = offset_and_length.first; + auto ed = st + offset_and_length.second - 1; + + std::string field = "bytes "; + field += std::to_string(st); + field += "-"; + field += std::to_string(ed); + field += "/"; + field += std::to_string(content_length); + return field; +} + +template +bool process_multipart_ranges_data(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length, SToken stoken, + CToken ctoken, Content content) { + for (size_t i = 0; i < req.ranges.size(); i++) { + ctoken("--"); + stoken(boundary); + ctoken("\r\n"); + if (!content_type.empty()) { + ctoken("Content-Type: "); + stoken(content_type); + ctoken("\r\n"); + } + + auto offset_and_length = + get_range_offset_and_length(req.ranges[i], content_length); + + ctoken("Content-Range: "); + stoken(make_content_range_header_field(offset_and_length, content_length)); + ctoken("\r\n"); + ctoken("\r\n"); + + if (!content(offset_and_length.first, offset_and_length.second)) { + return false; + } + ctoken("\r\n"); + } + + ctoken("--"); + stoken(boundary); + ctoken("--"); + + return true; +} + +inline void make_multipart_ranges_data(const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, + std::string &data) { + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data += token; }, + [&](const std::string &token) { data += token; }, + [&](size_t offset, size_t length) { + assert(offset + length <= content_length); + data += res.body.substr(offset, length); + return true; + }); +} + +inline size_t get_multipart_ranges_data_length(const Request &req, + const std::string &boundary, + const std::string &content_type, + size_t content_length) { + size_t data_length = 0; + + process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { data_length += token.size(); }, + [&](const std::string &token) { data_length += token.size(); }, + [&](size_t /*offset*/, size_t length) { + data_length += length; + return true; + }); + + return data_length; +} + +template +inline bool +write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, + const std::string &boundary, + const std::string &content_type, + size_t content_length, const T &is_shutting_down) { + return process_multipart_ranges_data( + req, boundary, content_type, content_length, + [&](const std::string &token) { strm.write(token); }, + [&](const std::string &token) { strm.write(token); }, + [&](size_t offset, size_t length) { + return write_content(strm, res.content_provider_, offset, length, + is_shutting_down); + }); +} + +inline bool expect_content(const Request &req) { + if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || + req.method == "PRI" || req.method == "DELETE") { + return true; + } + // TODO: check if Content-Length is set + return false; +} + +inline bool has_crlf(const std::string &s) { + auto p = s.c_str(); + while (*p) { + if (*p == '\r' || *p == '\n') { return true; } + p++; + } + return false; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + + return ss.str(); +} + +inline std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +inline std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +inline std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef _WIN32 +// NOTE: This code came up with the following stackoverflow post: +// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store +inline bool load_system_certs_on_windows(X509_STORE *store) { + auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); + if (!hStore) { return false; } + + auto result = false; + PCCERT_CONTEXT pContext = NULL; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + auto encoded_cert = + static_cast(pContext->pbCertEncoded); + + auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + CertFreeCertificateContext(pContext); + CertCloseStore(hStore, 0); + + return result; +} +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX +template +using CFObjectPtr = + std::unique_ptr::type, void (*)(CFTypeRef)>; + +inline void cf_object_ptr_deleter(CFTypeRef obj) { + if (obj) { CFRelease(obj); } +} + +inline bool retrieve_certs_from_keychain(CFObjectPtr &certs) { + CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; + CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, + kCFBooleanTrue}; + + CFObjectPtr query( + CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, + sizeof(keys) / sizeof(keys[0]), + &kCFTypeDictionaryKeyCallBacks, + &kCFTypeDictionaryValueCallBacks), + cf_object_ptr_deleter); + + if (!query) { return false; } + + CFTypeRef security_items = nullptr; + if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || + CFArrayGetTypeID() != CFGetTypeID(security_items)) { + return false; + } + + certs.reset(reinterpret_cast(security_items)); + return true; +} + +inline bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { + CFArrayRef root_security_items = nullptr; + if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { + return false; + } + + certs.reset(root_security_items); + return true; +} + +inline bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { + auto result = false; + for (auto i = 0; i < CFArrayGetCount(certs); ++i) { + const auto cert = reinterpret_cast( + CFArrayGetValueAtIndex(certs, i)); + + if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; } + + CFDataRef cert_data = nullptr; + if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != + errSecSuccess) { + continue; + } + + CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); + + auto encoded_cert = static_cast( + CFDataGetBytePtr(cert_data_ptr.get())); + + auto x509 = + d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); + + if (x509) { + X509_STORE_add_cert(store, x509); + X509_free(x509); + result = true; + } + } + + return result; +} + +inline bool load_system_certs_on_macos(X509_STORE *store) { + auto result = false; + CFObjectPtr certs(nullptr, cf_object_ptr_deleter); + if (retrieve_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store); + } + + if (retrieve_root_certs_from_keychain(certs) && certs) { + result = add_certs_to_x509_store(certs.get(), store) || result; + } + + return result; +} +#endif // TARGET_OS_OSX +#endif // _WIN32 +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef _WIN32 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} +#endif + +inline bool parse_www_authenticate(const Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + static auto re = std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[key] = val; + } + return true; + } + } + } + return false; +} + +class ContentProviderAdapter { +public: + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(content_provider) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); + } + +private: + ContentProviderWithoutLength content_provider_; +}; + +} // namespace detail + +inline std::string hosted_at(const std::string &hostname) { + std::vector addrs; + hosted_at(hostname, addrs); + if (addrs.empty()) { return std::string(); } + return addrs[0]; +} + +inline void hosted_at(const std::string &hostname, + std::vector &addrs) { + struct addrinfo hints; + struct addrinfo *result; + + memset(&hints, 0, sizeof(struct addrinfo)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + hints.ai_protocol = 0; + + if (getaddrinfo(hostname.c_str(), nullptr, &hints, &result)) { +#if defined __linux__ && !defined __ANDROID__ + res_init(); +#endif + return; + } + + for (auto rp = result; rp; rp = rp->ai_next) { + const auto &addr = + *reinterpret_cast(rp->ai_addr); + std::string ip; + auto dummy = -1; + if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, + dummy)) { + addrs.push_back(ip); + } + } + + freeaddrinfo(result); +} + +inline std::string append_query_params(const std::string &path, + const Params ¶ms) { + std::string path_with_query = path; + const static std::regex re("[^?]+\\?.*"); + auto delm = std::regex_match(path, re) ? '&' : '?'; + path_with_query += delm + detail::params_to_query_str(params); + return path_with_query; +} + +// Header utilities +inline std::pair +make_range_header(const Ranges &ranges) { + std::string field = "bytes="; + auto i = 0; + for (const auto &r : ranges) { + if (i != 0) { field += ", "; } + if (r.first != -1) { field += std::to_string(r.first); } + field += '-'; + if (r.second != -1) { field += std::to_string(r.second); } + i++; + } + return std::make_pair("Range", std::move(field)); +} + +inline std::pair +make_basic_authentication_header(const std::string &username, + const std::string &password, bool is_proxy) { + auto field = "Basic " + detail::base64_encode(username + ":" + password); + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +inline std::pair +make_bearer_token_authentication_header(const std::string &token, + bool is_proxy = false) { + auto field = "Bearer " + token; + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, std::move(field)); +} + +// Request implementation +inline bool Request::has_header(const std::string &key) const { + return detail::has_header(headers, key); +} + +inline std::string Request::get_header_value(const std::string &key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Request::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Request::set_header(const std::string &key, + const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline bool Request::has_param(const std::string &key) const { + return params.find(key) != params.end(); +} + +inline std::string Request::get_param_value(const std::string &key, + size_t id) const { + auto rng = params.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return std::string(); +} + +inline size_t Request::get_param_value_count(const std::string &key) const { + auto r = params.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline bool Request::is_multipart_form_data() const { + const auto &content_type = get_header_value("Content-Type"); + return !content_type.rfind("multipart/form-data", 0); +} + +inline bool Request::has_file(const std::string &key) const { + return files.find(key) != files.end(); +} + +inline MultipartFormData Request::get_file_value(const std::string &key) const { + auto it = files.find(key); + if (it != files.end()) { return it->second; } + return MultipartFormData(); +} + +inline std::vector +Request::get_file_values(const std::string &key) const { + std::vector values; + auto rng = files.equal_range(key); + for (auto it = rng.first; it != rng.second; it++) { + values.push_back(it->second); + } + return values; +} + +// Response implementation +inline bool Response::has_header(const std::string &key) const { + return headers.find(key) != headers.end(); +} + +inline std::string Response::get_header_value(const std::string &key, + size_t id) const { + return detail::get_header_value(headers, key, id, ""); +} + +inline size_t Response::get_header_value_count(const std::string &key) const { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +inline void Response::set_header(const std::string &key, + const std::string &val) { + if (!detail::has_crlf(key) && !detail::has_crlf(val)) { + headers.emplace(key, val); + } +} + +inline void Response::set_redirect(const std::string &url, int stat) { + if (!detail::has_crlf(url)) { + set_header("Location", url); + if (300 <= stat && stat < 400) { + this->status = stat; + } else { + this->status = StatusCode::Found_302; + } + } +} + +inline void Response::set_content(const char *s, size_t n, + const std::string &content_type) { + body.assign(s, n); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content(const std::string &s, + const std::string &content_type) { + set_content(s.data(), s.size(), content_type); +} + +inline void Response::set_content(std::string &&s, + const std::string &content_type) { + body = std::move(s); + + auto rng = headers.equal_range("Content-Type"); + headers.erase(rng.first, rng.second); + set_header("Content-Type", content_type); +} + +inline void Response::set_content_provider( + size_t in_length, const std::string &content_type, ContentProvider provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = in_length; + if (in_length > 0) { content_provider_ = std::move(provider); } + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = false; +} + +inline void Response::set_chunked_content_provider( + const std::string &content_type, ContentProviderWithoutLength provider, + ContentProviderResourceReleaser resource_releaser) { + set_header("Content-Type", content_type); + content_length_ = 0; + content_provider_ = detail::ContentProviderAdapter(std::move(provider)); + content_provider_resource_releaser_ = std::move(resource_releaser); + is_chunked_content_provider_ = true; +} + +// Result implementation +inline bool Result::has_request_header(const std::string &key) const { + return request_headers_.find(key) != request_headers_.end(); +} + +inline std::string Result::get_request_header_value(const std::string &key, + size_t id) const { + return detail::get_header_value(request_headers_, key, id, ""); +} + +inline size_t +Result::get_request_header_value_count(const std::string &key) const { + auto r = request_headers_.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +// Stream implementation +inline ssize_t Stream::write(const char *ptr) { + return write(ptr, strlen(ptr)); +} + +inline ssize_t Stream::write(const std::string &s) { + return write(s.data(), s.size()); +} + +namespace detail { + +// Socket stream implementation +inline SocketStream::SocketStream(socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) + : sock_(sock), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), read_buff_(read_buff_size_, 0) {} + +inline SocketStream::~SocketStream() = default; + +inline bool SocketStream::is_readable() const { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); +} + +inline ssize_t SocketStream::read(char *ptr, size_t size) { +#ifdef _WIN32 + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#else + size = (std::min)(size, + static_cast((std::numeric_limits::max)())); +#endif + + if (read_buff_off_ < read_buff_content_size_) { + auto remaining_size = read_buff_content_size_ - read_buff_off_; + if (size <= remaining_size) { + memcpy(ptr, read_buff_.data() + read_buff_off_, size); + read_buff_off_ += size; + return static_cast(size); + } else { + memcpy(ptr, read_buff_.data() + read_buff_off_, remaining_size); + read_buff_off_ += remaining_size; + return static_cast(remaining_size); + } + } + + if (!is_readable()) { return -1; } + + read_buff_off_ = 0; + read_buff_content_size_ = 0; + + if (size < read_buff_size_) { + auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, + CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + return n; + } else if (n <= static_cast(size)) { + memcpy(ptr, read_buff_.data(), static_cast(n)); + return n; + } else { + memcpy(ptr, read_buff_.data(), size); + read_buff_off_ = size; + read_buff_content_size_ = static_cast(n); + return static_cast(size); + } + } else { + return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + } +} + +inline ssize_t SocketStream::write(const char *ptr, size_t size) { + if (!is_writable()) { return -1; } + +#if defined(_WIN32) && !defined(_WIN64) + size = + (std::min)(size, static_cast((std::numeric_limits::max)())); +#endif + + return send_socket(sock_, ptr, size, CPPHTTPLIB_SEND_FLAGS); +} + +inline void SocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + return detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + return detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SocketStream::socket() const { return sock_; } + +// Buffer stream implementation +inline bool BufferStream::is_readable() const { return true; } + +inline bool BufferStream::is_writable() const { return true; } + +inline ssize_t BufferStream::read(char *ptr, size_t size) { +#if defined(_MSC_VER) && _MSC_VER < 1910 + auto len_read = buffer._Copy_s(ptr, size, size, position); +#else + auto len_read = buffer.copy(ptr, size, position); +#endif + position += static_cast(len_read); + return static_cast(len_read); +} + +inline ssize_t BufferStream::write(const char *ptr, size_t size) { + buffer.append(ptr, size); + return static_cast(size); +} + +inline void BufferStream::get_remote_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline void BufferStream::get_local_ip_and_port(std::string & /*ip*/, + int & /*port*/) const {} + +inline socket_t BufferStream::socket() const { return 0; } + +inline const std::string &BufferStream::get_buffer() const { return buffer; } + +inline PathParamsMatcher::PathParamsMatcher(const std::string &pattern) { + // One past the last ending position of a path param substring + std::size_t last_param_end = 0; + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + // Needed to ensure that parameter names are unique during matcher + // construction + // If exceptions are disabled, only last duplicate path + // parameter will be set + std::unordered_set param_name_set; +#endif + + while (true) { + const auto marker_pos = pattern.find(marker, last_param_end); + if (marker_pos == std::string::npos) { break; } + + static_fragments_.push_back( + pattern.substr(last_param_end, marker_pos - last_param_end)); + + const auto param_name_start = marker_pos + 1; + + auto sep_pos = pattern.find(separator, param_name_start); + if (sep_pos == std::string::npos) { sep_pos = pattern.length(); } + + auto param_name = + pattern.substr(param_name_start, sep_pos - param_name_start); + +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (param_name_set.find(param_name) != param_name_set.cend()) { + std::string msg = "Encountered path parameter '" + param_name + + "' multiple times in route pattern '" + pattern + "'."; + throw std::invalid_argument(msg); + } +#endif + + param_names_.push_back(std::move(param_name)); + + last_param_end = sep_pos + 1; + } + + if (last_param_end < pattern.length()) { + static_fragments_.push_back(pattern.substr(last_param_end)); + } +} + +inline bool PathParamsMatcher::match(Request &request) const { + request.matches = std::smatch(); + request.path_params.clear(); + request.path_params.reserve(param_names_.size()); + + // One past the position at which the path matched the pattern last time + std::size_t starting_pos = 0; + for (size_t i = 0; i < static_fragments_.size(); ++i) { + const auto &fragment = static_fragments_[i]; + + if (starting_pos + fragment.length() > request.path.length()) { + return false; + } + + // Avoid unnecessary allocation by using strncmp instead of substr + + // comparison + if (std::strncmp(request.path.c_str() + starting_pos, fragment.c_str(), + fragment.length()) != 0) { + return false; + } + + starting_pos += fragment.length(); + + // Should only happen when we have a static fragment after a param + // Example: '/users/:id/subscriptions' + // The 'subscriptions' fragment here does not have a corresponding param + if (i >= param_names_.size()) { continue; } + + auto sep_pos = request.path.find(separator, starting_pos); + if (sep_pos == std::string::npos) { sep_pos = request.path.length(); } + + const auto ¶m_name = param_names_[i]; + + request.path_params.emplace( + param_name, request.path.substr(starting_pos, sep_pos - starting_pos)); + + // Mark everythin up to '/' as matched + starting_pos = sep_pos + 1; + } + // Returns false if the path is longer than the pattern + return starting_pos >= request.path.length(); +} + +inline bool RegexMatcher::match(Request &request) const { + request.path_params.clear(); + return std::regex_match(request.path, request.matches, regex_); +} + +} // namespace detail + +// HTTP server implementation +inline Server::Server() + : new_task_queue( + [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { +#ifndef _WIN32 + signal(SIGPIPE, SIG_IGN); +#endif +} + +inline Server::~Server() = default; + +inline std::unique_ptr +Server::make_matcher(const std::string &pattern) { + if (pattern.find("/:") != std::string::npos) { + return detail::make_unique(pattern); + } else { + return detail::make_unique(pattern); + } +} + +inline Server &Server::Get(const std::string &pattern, Handler handler) { + get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, Handler handler) { + post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Post(const std::string &pattern, + HandlerWithContentReader handler) { + post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, Handler handler) { + put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Put(const std::string &pattern, + HandlerWithContentReader handler) { + put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, Handler handler) { + patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Patch(const std::string &pattern, + HandlerWithContentReader handler) { + patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, Handler handler) { + delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline Server &Server::Delete(const std::string &pattern, + HandlerWithContentReader handler) { + delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), + std::move(handler)); + return *this; +} + +inline Server &Server::Options(const std::string &pattern, Handler handler) { + options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; +} + +inline bool Server::set_base_dir(const std::string &dir, + const std::string &mount_point) { + return set_mount_point(mount_point, dir); +} + +inline bool Server::set_mount_point(const std::string &mount_point, + const std::string &dir, Headers headers) { + if (detail::is_dir(dir)) { + std::string mnt = !mount_point.empty() ? mount_point : "/"; + if (!mnt.empty() && mnt[0] == '/') { + base_dirs_.push_back({mnt, dir, std::move(headers)}); + return true; + } + } + return false; +} + +inline bool Server::remove_mount_point(const std::string &mount_point) { + for (auto it = base_dirs_.begin(); it != base_dirs_.end(); ++it) { + if (it->mount_point == mount_point) { + base_dirs_.erase(it); + return true; + } + } + return false; +} + +inline Server & +Server::set_file_extension_and_mimetype_mapping(const std::string &ext, + const std::string &mime) { + file_extension_and_mimetype_map_[ext] = mime; + return *this; +} + +inline Server &Server::set_default_file_mimetype(const std::string &mime) { + default_file_mimetype_ = mime; + return *this; +} + +inline Server &Server::set_file_request_handler(Handler handler) { + file_request_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler(HandlerWithResponse handler) { + error_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_error_handler(Handler handler) { + error_handler_ = [handler](const Request &req, Response &res) { + handler(req, res); + return HandlerResponse::Handled; + }; + return *this; +} + +inline Server &Server::set_exception_handler(ExceptionHandler handler) { + exception_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_pre_routing_handler(HandlerWithResponse handler) { + pre_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_post_routing_handler(Handler handler) { + post_routing_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_logger(Logger logger) { + logger_ = std::move(logger); + return *this; +} + +inline Server & +Server::set_expect_100_continue_handler(Expect100ContinueHandler handler) { + expect_100_continue_handler_ = std::move(handler); + return *this; +} + +inline Server &Server::set_address_family(int family) { + address_family_ = family; + return *this; +} + +inline Server &Server::set_tcp_nodelay(bool on) { + tcp_nodelay_ = on; + return *this; +} + +inline Server &Server::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); + return *this; +} + +inline Server &Server::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); + return *this; +} + +inline Server &Server::set_header_writer( + std::function const &writer) { + header_writer_ = writer; + return *this; +} + +inline Server &Server::set_keep_alive_max_count(size_t count) { + keep_alive_max_count_ = count; + return *this; +} + +inline Server &Server::set_keep_alive_timeout(time_t sec) { + keep_alive_timeout_sec_ = sec; + return *this; +} + +inline Server &Server::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; + return *this; +} + +inline Server &Server::set_idle_interval(time_t sec, time_t usec) { + idle_interval_sec_ = sec; + idle_interval_usec_ = usec; + return *this; +} + +inline Server &Server::set_payload_max_length(size_t length) { + payload_max_length_ = length; + return *this; +} + +inline bool Server::bind_to_port(const std::string &host, int port, + int socket_flags) { + return bind_internal(host, port, socket_flags) >= 0; +} +inline int Server::bind_to_any_port(const std::string &host, int socket_flags) { + return bind_internal(host, 0, socket_flags); +} + +inline bool Server::listen_after_bind() { + auto se = detail::scope_exit([&]() { done_ = true; }); + return listen_internal(); +} + +inline bool Server::listen(const std::string &host, int port, + int socket_flags) { + auto se = detail::scope_exit([&]() { done_ = true; }); + return bind_to_port(host, port, socket_flags) && listen_internal(); +} + +inline bool Server::is_running() const { return is_running_; } + +inline void Server::wait_until_ready() const { + while (!is_running() && !done_) { + std::this_thread::sleep_for(std::chrono::milliseconds{1}); + } +} + +inline void Server::stop() { + if (is_running_) { + assert(svr_sock_ != INVALID_SOCKET); + std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); + detail::shutdown_socket(sock); + detail::close_socket(sock); + } +} + +inline bool Server::parse_request_line(const char *s, Request &req) const { + auto len = strlen(s); + if (len < 2 || s[len - 2] != '\r' || s[len - 1] != '\n') { return false; } + len -= 2; + + { + size_t count = 0; + + detail::split(s, s + len, ' ', [&](const char *b, const char *e) { + switch (count) { + case 0: req.method = std::string(b, e); break; + case 1: req.target = std::string(b, e); break; + case 2: req.version = std::string(b, e); break; + default: break; + } + count++; + }); + + if (count != 3) { return false; } + } + + static const std::set methods{ + "GET", "HEAD", "POST", "PUT", "DELETE", + "CONNECT", "OPTIONS", "TRACE", "PATCH", "PRI"}; + + if (methods.find(req.method) == methods.end()) { return false; } + + if (req.version != "HTTP/1.1" && req.version != "HTTP/1.0") { return false; } + + { + // Skip URL fragment + for (size_t i = 0; i < req.target.size(); i++) { + if (req.target[i] == '#') { + req.target.erase(i); + break; + } + } + + size_t count = 0; + + detail::split(req.target.data(), req.target.data() + req.target.size(), '?', + 2, [&](const char *b, const char *e) { + switch (count) { + case 0: + req.path = detail::decode_url(std::string(b, e), false); + break; + case 1: { + if (e - b > 0) { + detail::parse_query_text(std::string(b, e), req.params); + } + break; + } + default: break; + } + count++; + }); + + if (count > 2) { return false; } + } + + return true; +} + +inline bool Server::write_response(Stream &strm, bool close_connection, + Request &req, Response &res) { + // NOTE: `req.ranges` should be empty, otherwise it will be applied + // incorrectly to the error content. + req.ranges.clear(); + return write_response_core(strm, close_connection, req, res, false); +} + +inline bool Server::write_response_with_content(Stream &strm, + bool close_connection, + const Request &req, + Response &res) { + return write_response_core(strm, close_connection, req, res, true); +} + +inline bool Server::write_response_core(Stream &strm, bool close_connection, + const Request &req, Response &res, + bool need_apply_ranges) { + assert(res.status != -1); + + if (400 <= res.status && error_handler_ && + error_handler_(req, res) == HandlerResponse::Handled) { + need_apply_ranges = true; + } + + std::string content_type; + std::string boundary; + if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } + + // Prepare additional headers + if (close_connection || req.get_header_value("Connection") == "close") { + res.set_header("Connection", "close"); + } else { + std::stringstream ss; + ss << "timeout=" << keep_alive_timeout_sec_ + << ", max=" << keep_alive_max_count_; + res.set_header("Keep-Alive", ss.str()); + } + + if (!res.has_header("Content-Type") && + (!res.body.empty() || res.content_length_ > 0 || res.content_provider_)) { + res.set_header("Content-Type", "text/plain"); + } + + if (!res.has_header("Content-Length") && res.body.empty() && + !res.content_length_ && !res.content_provider_) { + res.set_header("Content-Length", "0"); + } + + if (!res.has_header("Accept-Ranges") && req.method == "HEAD") { + res.set_header("Accept-Ranges", "bytes"); + } + + if (post_routing_handler_) { post_routing_handler_(req, res); } + + // Response line and headers + { + detail::BufferStream bstrm; + + if (!bstrm.write_format("HTTP/1.1 %d %s\r\n", res.status, + status_message(res.status))) { + return false; + } + + if (!header_writer_(bstrm, res.headers)) { return false; } + + // Flush buffer + auto &data = bstrm.get_buffer(); + detail::write_data(strm, data.data(), data.size()); + } + + // Body + auto ret = true; + if (req.method != "HEAD") { + if (!res.body.empty()) { + if (!detail::write_data(strm, res.body.data(), res.body.size())) { + ret = false; + } + } else if (res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; + } + } + } + + // Log + if (logger_) { logger_(req, res); } + + return ret; +} + +inline bool +Server::write_content_with_provider(Stream &strm, const Request &req, + Response &res, const std::string &boundary, + const std::string &content_type) { + auto is_shutting_down = [this]() { + return this->svr_sock_ == INVALID_SOCKET; + }; + + if (res.content_length_ > 0) { + if (req.ranges.empty()) { + return detail::write_content(strm, res.content_provider_, 0, + res.content_length_, is_shutting_down); + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + return detail::write_content(strm, res.content_provider_, + offset_and_length.first, + offset_and_length.second, is_shutting_down); + } else { + return detail::write_multipart_ranges_data( + strm, req, res, boundary, content_type, res.content_length_, + is_shutting_down); + } + } else { + if (res.is_chunked_content_provider_) { + auto type = detail::encoding_type(req, res); + + std::unique_ptr compressor; + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); +#endif + } else { + compressor = detail::make_unique(); + } + assert(compressor != nullptr); + + return detail::write_content_chunked(strm, res.content_provider_, + is_shutting_down, *compressor); + } else { + return detail::write_content_without_length(strm, res.content_provider_, + is_shutting_down); + } + } +} + +inline bool Server::read_content(Stream &strm, Request &req, Response &res) { + MultipartFormDataMap::iterator cur; + auto file_count = 0; + if (read_content_core( + strm, req, res, + // Regular + [&](const char *buf, size_t n) { + if (req.body.size() + n > req.body.max_size()) { return false; } + req.body.append(buf, n); + return true; + }, + // Multipart + [&](const MultipartFormData &file) { + if (file_count++ == CPPHTTPLIB_MULTIPART_FORM_DATA_FILE_MAX_COUNT) { + return false; + } + cur = req.files.emplace(file.name, file); + return true; + }, + [&](const char *buf, size_t n) { + auto &content = cur->second.content; + if (content.size() + n > content.max_size()) { return false; } + content.append(buf, n); + return true; + })) { + const auto &content_type = req.get_header_value("Content-Type"); + if (!content_type.find("application/x-www-form-urlencoded")) { + if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { + res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? + return false; + } + detail::parse_query_text(req.body, req.params); + } + return true; + } + return false; +} + +inline bool Server::read_content_with_content_receiver( + Stream &strm, Request &req, Response &res, ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) { + return read_content_core(strm, req, res, std::move(receiver), + std::move(multipart_header), + std::move(multipart_receiver)); +} + +inline bool +Server::read_content_core(Stream &strm, Request &req, Response &res, + ContentReceiver receiver, + MultipartContentHeader multipart_header, + ContentReceiver multipart_receiver) const { + detail::MultipartFormDataParser multipart_form_data_parser; + ContentReceiverWithProgress out; + + if (req.is_multipart_form_data()) { + const auto &content_type = req.get_header_value("Content-Type"); + std::string boundary; + if (!detail::parse_multipart_boundary(content_type, boundary)) { + res.status = StatusCode::BadRequest_400; + return false; + } + + multipart_form_data_parser.set_boundary(std::move(boundary)); + out = [&](const char *buf, size_t n, uint64_t /*off*/, uint64_t /*len*/) { + /* For debug + size_t pos = 0; + while (pos < n) { + auto read_size = (std::min)(1, n - pos); + auto ret = multipart_form_data_parser.parse( + buf + pos, read_size, multipart_receiver, multipart_header); + if (!ret) { return false; } + pos += read_size; + } + return true; + */ + return multipart_form_data_parser.parse(buf, n, multipart_receiver, + multipart_header); + }; + } else { + out = [receiver](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { return receiver(buf, n); }; + } + + if (req.method == "DELETE" && !req.has_header("Content-Length")) { + return true; + } + + if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, + out, true)) { + return false; + } + + if (req.is_multipart_form_data()) { + if (!multipart_form_data_parser.is_valid()) { + res.status = StatusCode::BadRequest_400; + return false; + } + } + + return true; +} + +inline bool Server::handle_file_request(const Request &req, Response &res, + bool head) { + for (const auto &entry : base_dirs_) { + // Prefix match + if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { + std::string sub_path = "/" + req.path.substr(entry.mount_point.size()); + if (detail::is_valid_path(sub_path)) { + auto path = entry.base_dir + sub_path; + if (path.back() == '/') { path += "index.html"; } + + if (detail::is_file(path)) { + for (const auto &kv : entry.headers) { + res.set_header(kv.first, kv.second); + } + + auto mm = std::make_shared(path.c_str()); + if (!mm->is_open()) { return false; } + + res.set_content_provider( + mm->size(), + detail::find_content_type(path, file_extension_and_mimetype_map_, + default_file_mimetype_), + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); + + if (!head && file_request_handler_) { + file_request_handler_(req, res); + } + + return true; + } + } + } + } + return false; +} + +inline socket_t +Server::create_server_socket(const std::string &host, int port, + int socket_flags, + SocketOptions socket_options) const { + return detail::create_socket( + host, std::string(), port, address_family_, socket_flags, tcp_nodelay_, + std::move(socket_options), + [](socket_t sock, struct addrinfo &ai) -> bool { + if (::bind(sock, ai.ai_addr, static_cast(ai.ai_addrlen))) { + return false; + } + if (::listen(sock, CPPHTTPLIB_LISTEN_BACKLOG)) { return false; } + return true; + }); +} + +inline int Server::bind_internal(const std::string &host, int port, + int socket_flags) { + if (!is_valid()) { return -1; } + + svr_sock_ = create_server_socket(host, port, socket_flags, socket_options_); + if (svr_sock_ == INVALID_SOCKET) { return -1; } + + if (port == 0) { + struct sockaddr_storage addr; + socklen_t addr_len = sizeof(addr); + if (getsockname(svr_sock_, reinterpret_cast(&addr), + &addr_len) == -1) { + return -1; + } + if (addr.ss_family == AF_INET) { + return ntohs(reinterpret_cast(&addr)->sin_port); + } else if (addr.ss_family == AF_INET6) { + return ntohs(reinterpret_cast(&addr)->sin6_port); + } else { + return -1; + } + } else { + return port; + } +} + +inline bool Server::listen_internal() { + auto ret = true; + is_running_ = true; + auto se = detail::scope_exit([&]() { is_running_ = false; }); + + { + std::unique_ptr task_queue(new_task_queue()); + + while (svr_sock_ != INVALID_SOCKET) { +#ifndef _WIN32 + if (idle_interval_sec_ > 0 || idle_interval_usec_ > 0) { +#endif + auto val = detail::select_read(svr_sock_, idle_interval_sec_, + idle_interval_usec_); + if (val == 0) { // Timeout + task_queue->on_idle(); + continue; + } +#ifndef _WIN32 + } +#endif + socket_t sock = accept(svr_sock_, nullptr, nullptr); + + if (sock == INVALID_SOCKET) { + if (errno == EMFILE) { + // The per-process limit of open file descriptors has been reached. + // Try to accept new connections after a short sleep. + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + continue; + } else if (errno == EINTR || errno == EAGAIN) { + continue; + } + if (svr_sock_ != INVALID_SOCKET) { + detail::close_socket(svr_sock_); + ret = false; + } else { + ; // The server socket was closed by user. + } + break; + } + + { +#ifdef _WIN32 + auto timeout = static_cast(read_timeout_sec_ * 1000 + + read_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(read_timeout_sec_); + tv.tv_usec = static_cast(read_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_RCVTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + { + +#ifdef _WIN32 + auto timeout = static_cast(write_timeout_sec_ * 1000 + + write_timeout_usec_ / 1000); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&timeout), sizeof(timeout)); +#else + timeval tv; + tv.tv_sec = static_cast(write_timeout_sec_); + tv.tv_usec = static_cast(write_timeout_usec_); + setsockopt(sock, SOL_SOCKET, SO_SNDTIMEO, + reinterpret_cast(&tv), sizeof(tv)); +#endif + } + + if (!task_queue->enqueue( + [this, sock]() { process_and_close_socket(sock); })) { + detail::shutdown_socket(sock); + detail::close_socket(sock); + } + } + + task_queue->shutdown(); + } + + return ret; +} + +inline bool Server::routing(Request &req, Response &res, Stream &strm) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + return true; + } + + // File handler + auto is_head_request = req.method == "HEAD"; + if ((req.method == "GET" || is_head_request) && + handle_file_request(req, res, is_head_request)) { + return true; + } + + if (detail::expect_content(req)) { + // Content reader handler + { + ContentReader reader( + [&](ContentReceiver receiver) { + return read_content_with_content_receiver( + strm, req, res, std::move(receiver), nullptr, nullptr); + }, + [&](MultipartContentHeader header, ContentReceiver receiver) { + return read_content_with_content_receiver(strm, req, res, nullptr, + std::move(header), + std::move(receiver)); + }); + + if (req.method == "POST") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + post_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PUT") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + put_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "PATCH") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + patch_handlers_for_content_reader_)) { + return true; + } + } else if (req.method == "DELETE") { + if (dispatch_request_for_content_reader( + req, res, std::move(reader), + delete_handlers_for_content_reader_)) { + return true; + } + } + } + + // Read content into `req.body` + if (!read_content(strm, req, res)) { return false; } + } + + // Regular handler + if (req.method == "GET" || req.method == "HEAD") { + return dispatch_request(req, res, get_handlers_); + } else if (req.method == "POST") { + return dispatch_request(req, res, post_handlers_); + } else if (req.method == "PUT") { + return dispatch_request(req, res, put_handlers_); + } else if (req.method == "DELETE") { + return dispatch_request(req, res, delete_handlers_); + } else if (req.method == "OPTIONS") { + return dispatch_request(req, res, options_handlers_); + } else if (req.method == "PATCH") { + return dispatch_request(req, res, patch_handlers_); + } + + res.status = StatusCode::BadRequest_400; + return false; +} + +inline bool Server::dispatch_request(Request &req, Response &res, + const Handlers &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res); + return true; + } + } + return false; +} + +inline void Server::apply_ranges(const Request &req, Response &res, + std::string &content_type, + std::string &boundary) const { + if (req.ranges.size() > 1) { + auto it = res.headers.find("Content-Type"); + if (it != res.headers.end()) { + content_type = it->second; + res.headers.erase(it); + } + + boundary = detail::make_multipart_data_boundary(); + + res.set_header("Content-Type", + "multipart/byteranges; boundary=" + boundary); + } + + auto type = detail::encoding_type(req, res); + + if (res.body.empty()) { + if (res.content_length_ > 0) { + size_t length = 0; + if (req.ranges.empty()) { + length = res.content_length_; + } else if (req.ranges.size() == 1) { + auto offset_and_length = detail::get_range_offset_and_length( + req.ranges[0], res.content_length_); + + length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.content_length_); + res.set_header("Content-Range", content_range); + } else { + length = detail::get_multipart_ranges_data_length( + req, boundary, content_type, res.content_length_); + } + res.set_header("Content-Length", std::to_string(length)); + } else { + if (res.content_provider_) { + if (res.is_chunked_content_provider_) { + res.set_header("Transfer-Encoding", "chunked"); + if (type == detail::EncodingType::Gzip) { + res.set_header("Content-Encoding", "gzip"); + } else if (type == detail::EncodingType::Brotli) { + res.set_header("Content-Encoding", "br"); + } + } + } + } + } else { + if (req.ranges.empty()) { + ; + } else if (req.ranges.size() == 1) { + auto offset_and_length = + detail::get_range_offset_and_length(req.ranges[0], res.body.size()); + auto offset = offset_and_length.first; + auto length = offset_and_length.second; + + auto content_range = detail::make_content_range_header_field( + offset_and_length, res.body.size()); + res.set_header("Content-Range", content_range); + + assert(offset + length <= res.body.size()); + res.body = res.body.substr(offset, length); + } else { + std::string data; + detail::make_multipart_ranges_data(req, res, boundary, content_type, + res.body.size(), data); + res.body.swap(data); + } + + if (type != detail::EncodingType::None) { + std::unique_ptr compressor; + std::string content_encoding; + + if (type == detail::EncodingType::Gzip) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + compressor = detail::make_unique(); + content_encoding = "gzip"; +#endif + } else if (type == detail::EncodingType::Brotli) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + compressor = detail::make_unique(); + content_encoding = "br"; +#endif + } + + if (compressor) { + std::string compressed; + if (compressor->compress(res.body.data(), res.body.size(), true, + [&](const char *data, size_t data_len) { + compressed.append(data, data_len); + return true; + })) { + res.body.swap(compressed); + res.set_header("Content-Encoding", content_encoding); + } + } + } + + auto length = std::to_string(res.body.size()); + res.set_header("Content-Length", length); + } +} + +inline bool Server::dispatch_request_for_content_reader( + Request &req, Response &res, ContentReader content_reader, + const HandlersForContentReader &handlers) const { + for (const auto &x : handlers) { + const auto &matcher = x.first; + const auto &handler = x.second; + + if (matcher->match(req)) { + handler(req, res, content_reader); + return true; + } + } + return false; +} + +inline bool +Server::process_request(Stream &strm, bool close_connection, + bool &connection_closed, + const std::function &setup_request) { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + // Connection has been closed on client + if (!line_reader.getline()) { return false; } + + Request req; + + Response res; + res.version = "HTTP/1.1"; + res.headers = default_headers_; + +#ifdef _WIN32 + // TODO: Increase FD_SETSIZE statically (libzmq), dynamically (MySQL). +#else +#ifndef CPPHTTPLIB_USE_POLL + // Socket file descriptor exceeded FD_SETSIZE... + if (strm.socket() >= FD_SETSIZE) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::InternalServerError_500; + return write_response(strm, close_connection, req, res); + } +#endif +#endif + + // Check if the request URI doesn't exceed the limit + if (line_reader.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + Headers dummy; + detail::read_headers(strm, dummy); + res.status = StatusCode::UriTooLong_414; + return write_response(strm, close_connection, req, res); + } + + // Request line and headers + if (!parse_request_line(line_reader.ptr(), req) || + !detail::read_headers(strm, req.headers)) { + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + + if (req.get_header_value("Connection") == "close") { + connection_closed = true; + } + + if (req.version == "HTTP/1.0" && + req.get_header_value("Connection") != "Keep-Alive") { + connection_closed = true; + } + + strm.get_remote_ip_and_port(req.remote_addr, req.remote_port); + req.set_header("REMOTE_ADDR", req.remote_addr); + req.set_header("REMOTE_PORT", std::to_string(req.remote_port)); + + strm.get_local_ip_and_port(req.local_addr, req.local_port); + req.set_header("LOCAL_ADDR", req.local_addr); + req.set_header("LOCAL_PORT", std::to_string(req.local_port)); + + if (req.has_header("Range")) { + const auto &range_header_value = req.get_header_value("Range"); + if (!detail::parse_range_header(range_header_value, req.ranges)) { + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + } + + if (setup_request) { setup_request(req); } + + if (req.get_header_value("Expect") == "100-continue") { + int status = StatusCode::Continue_100; + if (expect_100_continue_handler_) { + status = expect_100_continue_handler_(req, res); + } + switch (status) { + case StatusCode::Continue_100: + case StatusCode::ExpectationFailed_417: + strm.write_format("HTTP/1.1 %d %s\r\n\r\n", status, + status_message(status)); + break; + default: return write_response(strm, close_connection, req, res); + } + } + + // Routing + auto routed = false; +#ifdef CPPHTTPLIB_NO_EXCEPTIONS + routed = routing(req, res, strm); +#else + try { + routed = routing(req, res, strm); + } catch (std::exception &e) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + std::string val; + auto s = e.what(); + for (size_t i = 0; s[i]; i++) { + switch (s[i]) { + case '\r': val += "\\r"; break; + case '\n': val += "\\n"; break; + default: val += s[i]; break; + } + } + res.set_header("EXCEPTION_WHAT", val); + } + } catch (...) { + if (exception_handler_) { + auto ep = std::current_exception(); + exception_handler_(req, res, ep); + routed = true; + } else { + res.status = StatusCode::InternalServerError_500; + res.set_header("EXCEPTION_WHAT", "UNKNOWN"); + } + } +#endif + if (routed) { + if (res.status == -1) { + res.status = req.ranges.empty() ? StatusCode::OK_200 + : StatusCode::PartialContent_206; + } + + if (detail::range_error(req, res)) { + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + res.status = StatusCode::RangeNotSatisfiable_416; + return write_response(strm, close_connection, req, res); + } + + return write_response_with_content(strm, close_connection, req, res); + } else { + if (res.status == -1) { res.status = StatusCode::NotFound_404; } + + return write_response(strm, close_connection, req, res); + } +} + +inline bool Server::is_valid() const { return true; } + +inline bool Server::process_and_close_socket(socket_t sock) { + auto ret = detail::process_server_socket( + svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [this](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + nullptr); + }); + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// HTTP client implementation +inline ClientImpl::ClientImpl(const std::string &host) + : ClientImpl(host, 80, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port) + : ClientImpl(host, port, std::string(), std::string()) {} + +inline ClientImpl::ClientImpl(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : host_(host), port_(port), + host_and_port_(adjust_host_string(host) + ":" + std::to_string(port)), + client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} + +inline ClientImpl::~ClientImpl() { + std::lock_guard guard(socket_mutex_); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline bool ClientImpl::is_valid() const { return true; } + +inline void ClientImpl::copy_settings(const ClientImpl &rhs) { + client_cert_path_ = rhs.client_cert_path_; + client_key_path_ = rhs.client_key_path_; + connection_timeout_sec_ = rhs.connection_timeout_sec_; + read_timeout_sec_ = rhs.read_timeout_sec_; + read_timeout_usec_ = rhs.read_timeout_usec_; + write_timeout_sec_ = rhs.write_timeout_sec_; + write_timeout_usec_ = rhs.write_timeout_usec_; + basic_auth_username_ = rhs.basic_auth_username_; + basic_auth_password_ = rhs.basic_auth_password_; + bearer_token_auth_token_ = rhs.bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; +#endif + keep_alive_ = rhs.keep_alive_; + follow_location_ = rhs.follow_location_; + url_encode_ = rhs.url_encode_; + address_family_ = rhs.address_family_; + tcp_nodelay_ = rhs.tcp_nodelay_; + socket_options_ = rhs.socket_options_; + compress_ = rhs.compress_; + decompress_ = rhs.decompress_; + interface_ = rhs.interface_; + proxy_host_ = rhs.proxy_host_; + proxy_port_ = rhs.proxy_port_; + proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; + proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; + proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + ca_cert_store_ = rhs.ca_cert_store_; +#endif +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + server_certificate_verification_ = rhs.server_certificate_verification_; +#endif + logger_ = rhs.logger_; +} + +inline socket_t ClientImpl::create_client_socket(Error &error) const { + if (!proxy_host_.empty() && proxy_port_ != -1) { + return detail::create_client_socket( + proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); + } + + // Check is custom IP specified for host_ + std::string ip; + auto it = addr_map_.find(host_); + if (it != addr_map_.end()) { ip = it->second; } + + return detail::create_client_socket( + host_, ip, port_, address_family_, tcp_nodelay_, socket_options_, + connection_timeout_sec_, connection_timeout_usec_, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, interface_, + error); +} + +inline bool ClientImpl::create_and_connect_socket(Socket &socket, + Error &error) { + auto sock = create_client_socket(error); + if (sock == INVALID_SOCKET) { return false; } + socket.sock = sock; + return true; +} + +inline void ClientImpl::shutdown_ssl(Socket & /*socket*/, + bool /*shutdown_gracefully*/) { + // If there are any requests in flight from threads other than us, then it's + // a thread-unsafe race because individual ssl* objects are not thread-safe. + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); +} + +inline void ClientImpl::shutdown_socket(Socket &socket) const { + if (socket.sock == INVALID_SOCKET) { return; } + detail::shutdown_socket(socket.sock); +} + +inline void ClientImpl::close_socket(Socket &socket) { + // If there are requests in flight in another thread, usually closing + // the socket will be fine and they will simply receive an error when + // using the closed socket, but it is still a bug since rarely the OS + // may reassign the socket id to be used for a new socket, and then + // suddenly they will be operating on a live socket that is different + // than the one they intended! + assert(socket_requests_in_flight_ == 0 || + socket_requests_are_from_thread_ == std::this_thread::get_id()); + + // It is also a bug if this happens while SSL is still active +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + assert(socket.ssl == nullptr); +#endif + if (socket.sock == INVALID_SOCKET) { return; } + detail::close_socket(socket.sock); + socket.sock = INVALID_SOCKET; +} + +inline bool ClientImpl::read_response_line(Stream &strm, const Request &req, + Response &res) const { + std::array buf{}; + + detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); + + if (!line_reader.getline()) { return false; } + +#ifdef CPPHTTPLIB_ALLOW_LF_AS_LINE_TERMINATOR + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r?\n"); +#else + const static std::regex re("(HTTP/1\\.[01]) (\\d{3})(?: (.*?))?\r\n"); +#endif + + std::cmatch m; + if (!std::regex_match(line_reader.ptr(), m, re)) { + return req.method == "CONNECT"; + } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + + // Ignore '100 Continue' + while (res.status == StatusCode::Continue_100) { + if (!line_reader.getline()) { return false; } // CRLF + if (!line_reader.getline()) { return false; } // next response line + + if (!std::regex_match(line_reader.ptr(), m, re)) { return false; } + res.version = std::string(m[1]); + res.status = std::stoi(std::string(m[2])); + res.reason = std::string(m[3]); + } + + return true; +} + +inline bool ClientImpl::send(Request &req, Response &res, Error &error) { + std::lock_guard request_mutex_guard(request_mutex_); + auto ret = send_(req, res, error); + if (error == Error::SSLPeerCouldBeClosed_) { + assert(!ret); + ret = send_(req, res, error); + } + return ret; +} + +inline bool ClientImpl::send_(Request &req, Response &res, Error &error) { + { + std::lock_guard guard(socket_mutex_); + + // Set this to false immediately - if it ever gets set to true by the end of + // the request, we know another thread instructed us to close the socket. + socket_should_be_closed_when_request_is_done_ = false; + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); + if (!is_alive) { + // Attempt to avoid sigpipe by shutting down nongracefully if it seems + // like the other side has already closed the connection Also, there + // cannot be any requests in flight from other threads since we locked + // request_mutex_, so safe to close everything immediately + const bool shutdown_gracefully = false; + shutdown_ssl(socket_, shutdown_gracefully); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!create_and_connect_socket(socket_, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // TODO: refactoring + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + auto success = false; + if (!scli.connect_with_proxy(socket_, res, success, error)) { + return success; + } + } + + if (!scli.initialize_ssl(socket_, error)) { return false; } + } +#endif + } + + // Mark the current socket as being in use so that it cannot be closed by + // anyone else while this request is ongoing, even though we will be + // releasing the mutex. + if (socket_requests_in_flight_ > 1) { + assert(socket_requests_are_from_thread_ == std::this_thread::get_id()); + } + socket_requests_in_flight_ += 1; + socket_requests_are_from_thread_ = std::this_thread::get_id(); + } + + for (const auto &header : default_headers_) { + if (req.headers.find(header.first) == req.headers.end()) { + req.headers.insert(header); + } + } + + auto ret = false; + auto close_connection = !keep_alive_; + + auto se = detail::scope_exit([&]() { + // Briefly lock mutex in order to mark that a request is no longer ongoing + std::lock_guard guard(socket_mutex_); + socket_requests_in_flight_ -= 1; + if (socket_requests_in_flight_ <= 0) { + assert(socket_requests_in_flight_ == 0); + socket_requests_are_from_thread_ = std::thread::id(); + } + + if (socket_should_be_closed_when_request_is_done_ || close_connection || + !ret) { + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + }); + + ret = process_socket(socket_, [&](Stream &strm) { + return handle_request(strm, req, res, close_connection, error); + }); + + if (!ret) { + if (error == Error::Success) { error = Error::Unknown; } + } + + return ret; +} + +inline Result ClientImpl::send(const Request &req) { + auto req2 = req; + return send_(std::move(req2)); +} + +inline Result ClientImpl::send_(Request &&req) { + auto res = detail::make_unique(); + auto error = Error::Success; + auto ret = send(req, *res, error); + return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; +} + +inline bool ClientImpl::handle_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + if (req.path.empty()) { + error = Error::Connection; + return false; + } + + auto req_save = req; + + bool ret; + + if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + auto req2 = req; + req2.path = "http://" + host_and_port_ + req.path; + ret = process_request(strm, req2, res, close_connection, error); + req = req2; + req.path = req_save.path; + } else { + ret = process_request(strm, req, res, close_connection, error); + } + + if (!ret) { return false; } + + if (res.get_header_value("Connection") == "close" || + (res.version == "HTTP/1.0" && res.reason != "Connection established")) { + // TODO this requires a not-entirely-obvious chain of calls to be correct + // for this to be safe. + + // This is safe to call because handle_request is only called by send_ + // which locks the request mutex during the process. It would be a bug + // to call it from a different thread since it's a thread-safety issue + // to do these things to the socket if another thread is using the socket. + std::lock_guard guard(socket_mutex_); + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); + } + + if (300 < res.status && res.status < 400 && follow_location_) { + req = req_save; + ret = redirect(req, res, error); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if ((res.status == StatusCode::Unauthorized_401 || + res.status == StatusCode::ProxyAuthenticationRequired_407) && + req.authorization_count_ < 5) { + auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + const auto &username = + is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; + const auto &password = + is_proxy ? proxy_digest_auth_password_ : digest_auth_password_; + + if (!username.empty() && !password.empty()) { + std::map auth; + if (detail::parse_www_authenticate(res, auth, is_proxy)) { + Request new_req = req; + new_req.authorization_count_ += 1; + new_req.headers.erase(is_proxy ? "Proxy-Authorization" + : "Authorization"); + new_req.headers.insert(detail::make_digest_authentication_header( + req, auth, new_req.authorization_count_, detail::random_string(10), + username, password, is_proxy)); + + Response new_res; + + ret = send(new_req, new_res, error); + if (ret) { res = new_res; } + } + } + } +#endif + + return ret; +} + +inline bool ClientImpl::redirect(Request &req, Response &res, Error &error) { + if (req.redirect_count_ == 0) { + error = Error::ExceedRedirectCount; + return false; + } + + auto location = res.get_header_value("location"); + if (location.empty()) { return false; } + + const static std::regex re( + R"((?:(https?):)?(?://(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + + std::smatch m; + if (!std::regex_match(location, m, re)) { return false; } + + auto scheme = is_ssl() ? "https" : "http"; + + auto next_scheme = m[1].str(); + auto next_host = m[2].str(); + if (next_host.empty()) { next_host = m[3].str(); } + auto port_str = m[4].str(); + auto next_path = m[5].str(); + auto next_query = m[6].str(); + + auto next_port = port_; + if (!port_str.empty()) { + next_port = std::stoi(port_str); + } else if (!next_scheme.empty()) { + next_port = next_scheme == "https" ? 443 : 80; + } + + if (next_scheme.empty()) { next_scheme = scheme; } + if (next_host.empty()) { next_host = host_; } + if (next_path.empty()) { next_path = "/"; } + + auto path = detail::decode_url(next_path, true) + next_query; + + if (next_scheme == scheme && next_host == host_ && next_port == port_) { + return detail::redirect(*this, req, res, path, location, error); + } else { + if (next_scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSLClient cli(next_host, next_port); + cli.copy_settings(*this); + if (ca_cert_store_) { cli.set_ca_cert_store(ca_cert_store_); } + return detail::redirect(cli, req, res, path, location, error); +#else + return false; +#endif + } else { + ClientImpl cli(next_host, next_port); + cli.copy_settings(*this); + return detail::redirect(cli, req, res, path, location, error); + } + } +} + +inline bool ClientImpl::write_content_with_provider(Stream &strm, + const Request &req, + Error &error) const { + auto is_shutting_down = []() { return false; }; + + if (req.is_chunked_content_provider_) { + // TODO: Brotli support + std::unique_ptr compressor; +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { + compressor = detail::make_unique(); + } else +#endif + { + compressor = detail::make_unique(); + } + + return detail::write_content_chunked(strm, req.content_provider_, + is_shutting_down, *compressor, error); + } else { + return detail::write_content(strm, req.content_provider_, 0, + req.content_length_, is_shutting_down, error); + } +} + +inline bool ClientImpl::write_request(Stream &strm, Request &req, + bool close_connection, Error &error) { + // Prepare additional headers + if (close_connection) { + if (!req.has_header("Connection")) { + req.set_header("Connection", "close"); + } + } + + if (!req.has_header("Host")) { + if (is_ssl()) { + if (port_ == 443) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } else { + if (port_ == 80) { + req.set_header("Host", host_); + } else { + req.set_header("Host", host_and_port_); + } + } + } + + if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!req.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + req.set_header("User-Agent", agent); + } +#endif + + if (req.body.empty()) { + if (req.content_provider_) { + if (!req.is_chunked_content_provider_) { + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.content_length_); + req.set_header("Content-Length", length); + } + } + } else { + if (req.method == "POST" || req.method == "PUT" || + req.method == "PATCH") { + req.set_header("Content-Length", "0"); + } + } + } else { + if (!req.has_header("Content-Type")) { + req.set_header("Content-Type", "text/plain"); + } + + if (!req.has_header("Content-Length")) { + auto length = std::to_string(req.body.size()); + req.set_header("Content-Length", length); + } + } + + if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_basic_authentication_header( + basic_auth_username_, basic_auth_password_, false)); + } + } + + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + } + + if (!bearer_token_auth_token_.empty()) { + if (!req.has_header("Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + bearer_token_auth_token_, false)); + } + } + + if (!proxy_bearer_token_auth_token_.empty()) { + if (!req.has_header("Proxy-Authorization")) { + req.headers.insert(make_bearer_token_authentication_header( + proxy_bearer_token_auth_token_, true)); + } + } + + // Request line and headers + { + detail::BufferStream bstrm; + + const auto &path = url_encode_ ? detail::encode_url(req.path) : req.path; + bstrm.write_format("%s %s HTTP/1.1\r\n", req.method.c_str(), path.c_str()); + + header_writer_(bstrm, req.headers); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { + error = Error::Write; + return false; + } + } + + // Body + if (req.body.empty()) { + return write_content_with_provider(strm, req, error); + } + + if (!detail::write_data(strm, req.body.data(), req.body.size())) { + error = Error::Write; + return false; + } + + return true; +} + +inline std::unique_ptr ClientImpl::send_with_content_provider( + Request &req, const char *body, size_t content_length, + ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type, Error &error) { + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_) { req.set_header("Content-Encoding", "gzip"); } +#endif + +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (compress_ && !content_provider_without_length) { + // TODO: Brotli support + detail::gzip_compressor compressor; + + if (content_provider) { + auto ok = true; + size_t offset = 0; + DataSink data_sink; + + data_sink.write = [&](const char *data, size_t data_len) -> bool { + if (ok) { + auto last = offset + data_len == content_length; + + auto ret = compressor.compress( + data, data_len, last, + [&](const char *compressed_data, size_t compressed_data_len) { + req.body.append(compressed_data, compressed_data_len); + return true; + }); + + if (ret) { + offset += data_len; + } else { + ok = false; + } + } + return ok; + }; + + while (ok && offset < content_length) { + if (!content_provider(offset, content_length - offset, data_sink)) { + error = Error::Canceled; + return nullptr; + } + } + } else { + if (!compressor.compress(body, content_length, true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { + error = Error::Compression; + return nullptr; + } + } + } else +#endif + { + if (content_provider) { + req.content_length_ = content_length; + req.content_provider_ = std::move(content_provider); + req.is_chunked_content_provider_ = false; + } else if (content_provider_without_length) { + req.content_length_ = 0; + req.content_provider_ = detail::ContentProviderAdapter( + std::move(content_provider_without_length)); + req.is_chunked_content_provider_ = true; + req.set_header("Transfer-Encoding", "chunked"); + } else { + req.body.assign(body, content_length); + } + } + + auto res = detail::make_unique(); + return send(req, *res, error) ? std::move(res) : nullptr; +} + +inline Result ClientImpl::send_with_content_provider( + const std::string &method, const std::string &path, const Headers &headers, + const char *body, size_t content_length, ContentProvider content_provider, + ContentProviderWithoutLength content_provider_without_length, + const std::string &content_type) { + Request req; + req.method = method; + req.headers = headers; + req.path = path; + + auto error = Error::Success; + + auto res = send_with_content_provider( + req, body, content_length, std::move(content_provider), + std::move(content_provider_without_length), content_type, error); + + return Result{std::move(res), error, std::move(req.headers)}; +} + +inline std::string +ClientImpl::adjust_host_string(const std::string &host) const { + if (host.find(':') != std::string::npos) { return "[" + host + "]"; } + return host; +} + +inline bool ClientImpl::process_request(Stream &strm, Request &req, + Response &res, bool close_connection, + Error &error) { + // Send request + if (!write_request(strm, req, close_connection, error)) { return false; } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + if (!is_proxy_enabled) { + char buf[1]; + if (SSL_peek(socket_.ssl, buf, 1) == 0 && + SSL_get_error(socket_.ssl, 0) == SSL_ERROR_ZERO_RETURN) { + error = Error::SSLPeerCouldBeClosed_; + return false; + } + } + } +#endif + + // Receive response and headers + if (!read_response_line(strm, req, res) || + !detail::read_headers(strm, res.headers)) { + error = Error::Read; + return false; + } + + // Body + if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && + req.method != "CONNECT") { + auto redirect = 300 < res.status && res.status < 400 && follow_location_; + + if (req.response_handler && !redirect) { + if (!req.response_handler(res)) { + error = Error::Canceled; + return false; + } + } + + auto out = + req.content_receiver + ? static_cast( + [&](const char *buf, size_t n, uint64_t off, uint64_t len) { + if (redirect) { return true; } + auto ret = req.content_receiver(buf, n, off, len); + if (!ret) { error = Error::Canceled; } + return ret; + }) + : static_cast( + [&](const char *buf, size_t n, uint64_t /*off*/, + uint64_t /*len*/) { + if (res.body.size() + n > res.body.max_size()) { + return false; + } + res.body.append(buf, n); + return true; + }); + + auto progress = [&](uint64_t current, uint64_t total) { + if (!req.progress || redirect) { return true; } + auto ret = req.progress(current, total); + if (!ret) { error = Error::Canceled; } + return ret; + }; + + int dummy_status; + if (!detail::read_content(strm, res, (std::numeric_limits::max)(), + dummy_status, std::move(progress), std::move(out), + decompress_)) { + if (error != Error::Canceled) { error = Error::Read; } + return false; + } + } + + // Log + if (logger_) { logger_(req, res); } + + return true; +} + +inline ContentProviderWithoutLength ClientImpl::get_multipart_content_provider( + const std::string &boundary, const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) const { + size_t cur_item = 0; + size_t cur_start = 0; + // cur_item and cur_start are copied to within the std::function and maintain + // state between successive calls + return [&, cur_item, cur_start](size_t offset, + DataSink &sink) mutable -> bool { + if (!offset && !items.empty()) { + sink.os << detail::serialize_multipart_formdata(items, boundary, false); + return true; + } else if (cur_item < provider_items.size()) { + if (!cur_start) { + const auto &begin = detail::serialize_multipart_formdata_item_begin( + provider_items[cur_item], boundary); + offset += begin.size(); + cur_start = offset; + sink.os << begin; + } + + DataSink cur_sink; + auto has_data = true; + cur_sink.write = sink.write; + cur_sink.done = [&]() { has_data = false; }; + + if (!provider_items[cur_item].provider(offset - cur_start, cur_sink)) { + return false; + } + + if (!has_data) { + sink.os << detail::serialize_multipart_formdata_item_end(); + cur_item++; + cur_start = 0; + } + return true; + } else { + sink.os << detail::serialize_multipart_formdata_finish(boundary); + sink.done(); + return true; + } + }; +} + +inline bool +ClientImpl::process_socket(const Socket &socket, + std::function callback) { + return detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, std::move(callback)); +} + +inline bool ClientImpl::is_ssl() const { return false; } + +inline Result ClientImpl::Get(const std::string &path) { + return Get(path, Headers(), Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, Progress progress) { + return Get(path, Headers(), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers) { + return Get(path, headers, Progress()); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver) { + return Get(path, Headers(), nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver) { + return Get(path, headers, nullptr, std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), nullptr); +} + +inline Result ClientImpl::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, Headers(), std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + Request req; + req.method = "GET"; + req.path = path; + req.headers = headers; + req.response_handler = std::move(response_handler); + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + uint64_t /*offset*/, uint64_t /*total_length*/) { + return content_receiver(data, data_length); + }; + req.progress = std::move(progress); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress) { + if (params.empty()) { return Get(path, headers); } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, + Progress progress) { + return Get(path, params, headers, nullptr, std::move(content_receiver), + std::move(progress)); +} + +inline Result ClientImpl::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, + Progress progress) { + if (params.empty()) { + return Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); + } + + std::string path_with_query = append_query_params(path, params); + return Get(path_with_query, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result ClientImpl::Head(const std::string &path) { + return Head(path, Headers()); +} + +inline Result ClientImpl::Head(const std::string &path, + const Headers &headers) { + Request req; + req.method = "HEAD"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline Result ClientImpl::Post(const std::string &path) { + return Post(path, std::string(), std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, + const Headers &headers) { + return Post(path, headers, nullptr, 0, std::string()); +} + +inline Result ClientImpl::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Post(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body, content_length, + nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const std::string &body, + const std::string &content_type) { + return Post(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, body.data(), + body.size(), nullptr, nullptr, + content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Params ¶ms) { + return Post(path, Headers(), params); +} + +inline Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Post(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Post(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Post(const std::string &path, + const MultipartFormDataItems &items) { + return Post(path, Headers(), items); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Post(path, headers, body, content_type); +} + +inline Result +ClientImpl::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "POST", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type); +} + +inline Result ClientImpl::Put(const std::string &path) { + return Put(path, std::string(), std::string()); +} + +inline Result ClientImpl::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Put(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body, content_length, + nullptr, nullptr, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const std::string &body, + const std::string &content_type) { + return Put(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, body.data(), + body.size(), nullptr, nullptr, + content_type); +} + +inline Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Put(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Params ¶ms) { + return Put(path, Headers(), params); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + auto query = detail::params_to_query_str(params); + return Put(path, headers, query, "application/x-www-form-urlencoded"); +} + +inline Result ClientImpl::Put(const std::string &path, + const MultipartFormDataItems &items) { + return Put(path, Headers(), items); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + if (!detail::is_multipart_boundary_chars_valid(boundary)) { + return Result{nullptr, Error::UnsupportedMultipartBoundaryChars}; + } + + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + const auto &body = detail::serialize_multipart_formdata(items, boundary); + return Put(path, headers, body, content_type); +} + +inline Result +ClientImpl::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + const auto &boundary = detail::make_multipart_data_boundary(); + const auto &content_type = + detail::serialize_multipart_formdata_get_content_type(boundary); + return send_with_content_provider( + "PUT", path, headers, nullptr, 0, nullptr, + get_multipart_content_provider(boundary, items, provider_items), + content_type); +} +inline Result ClientImpl::Patch(const std::string &path) { + return Patch(path, std::string(), std::string()); +} + +inline Result ClientImpl::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Patch(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, body, + content_length, nullptr, nullptr, + content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + const std::string &body, + const std::string &content_type) { + return Patch(path, Headers(), body, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, body.data(), + body.size(), nullptr, nullptr, + content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return Patch(path, Headers(), std::move(content_provider), content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, + content_length, std::move(content_provider), + nullptr, content_type); +} + +inline Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, + std::move(content_provider), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path) { + return Delete(path, Headers(), std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers) { + return Delete(path, headers, std::string(), std::string()); +} + +inline Result ClientImpl::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return Delete(path, Headers(), body, content_length, content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, const char *body, + size_t content_length, + const std::string &content_type) { + Request req; + req.method = "DELETE"; + req.headers = headers; + req.path = path; + + if (!content_type.empty()) { req.set_header("Content-Type", content_type); } + req.body.assign(body, content_length); + + return send_(std::move(req)); +} + +inline Result ClientImpl::Delete(const std::string &path, + const std::string &body, + const std::string &content_type) { + return Delete(path, Headers(), body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Delete(const std::string &path, + const Headers &headers, + const std::string &body, + const std::string &content_type) { + return Delete(path, headers, body.data(), body.size(), content_type); +} + +inline Result ClientImpl::Options(const std::string &path) { + return Options(path, Headers()); +} + +inline Result ClientImpl::Options(const std::string &path, + const Headers &headers) { + Request req; + req.method = "OPTIONS"; + req.headers = headers; + req.path = path; + + return send_(std::move(req)); +} + +inline void ClientImpl::stop() { + std::lock_guard guard(socket_mutex_); + + // If there is anything ongoing right now, the ONLY thread-safe thing we can + // do is to shutdown_socket, so that threads using this socket suddenly + // discover they can't read/write any more and error out. Everything else + // (closing the socket, shutting ssl down) is unsafe because these actions are + // not thread-safe. + if (socket_requests_in_flight_ > 0) { + shutdown_socket(socket_); + + // Aside from that, we set a flag for the socket to be closed when we're + // done. + socket_should_be_closed_when_request_is_done_ = true; + return; + } + + // Otherwise, still holding the mutex, we can shut everything down ourselves + shutdown_ssl(socket_, true); + shutdown_socket(socket_); + close_socket(socket_); +} + +inline std::string ClientImpl::host() const { return host_; } + +inline int ClientImpl::port() const { return port_; } + +inline size_t ClientImpl::is_socket_open() const { + std::lock_guard guard(socket_mutex_); + return socket_.is_open(); +} + +inline socket_t ClientImpl::socket() const { return socket_.sock; } + +inline void ClientImpl::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +inline void ClientImpl::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +inline void ClientImpl::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +inline void ClientImpl::set_basic_auth(const std::string &username, + const std::string &password) { + basic_auth_username_ = username; + basic_auth_password_ = password; +} + +inline void ClientImpl::set_bearer_token_auth(const std::string &token) { + bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; +} +#endif + +inline void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } + +inline void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } + +inline void ClientImpl::set_url_encode(bool on) { url_encode_ = on; } + +inline void +ClientImpl::set_hostname_addr_map(std::map addr_map) { + addr_map_ = std::move(addr_map); +} + +inline void ClientImpl::set_default_headers(Headers headers) { + default_headers_ = std::move(headers); +} + +inline void ClientImpl::set_header_writer( + std::function const &writer) { + header_writer_ = writer; +} + +inline void ClientImpl::set_address_family(int family) { + address_family_ = family; +} + +inline void ClientImpl::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +inline void ClientImpl::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +inline void ClientImpl::set_compress(bool on) { compress_ = on; } + +inline void ClientImpl::set_decompress(bool on) { decompress_ = on; } + +inline void ClientImpl::set_interface(const std::string &intf) { + interface_ = intf; +} + +inline void ClientImpl::set_proxy(const std::string &host, int port) { + proxy_host_ = host; + proxy_port_ = port; +} + +inline void ClientImpl::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + proxy_basic_auth_username_ = username; + proxy_basic_auth_password_ = password; +} + +inline void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { + proxy_bearer_token_auth_token_ = token; +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; +} + +inline void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + ca_cert_file_path_ = ca_cert_file_path; + ca_cert_dir_path_ = ca_cert_dir_path; +} + +inline void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store && ca_cert_store != ca_cert_store_) { + ca_cert_store_ = ca_cert_store; + } +} + +inline X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, + std::size_t size) const { + auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); + if (!mem) { return nullptr; } + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { + BIO_free_all(mem); + return nullptr; + } + + auto cts = X509_STORE_new(); + if (cts) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + + if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + BIO_free_all(mem); + return cts; +} + +inline void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} +#endif + +inline void ClientImpl::set_logger(Logger logger) { + logger_ = std::move(logger); +} + +/* + * SSL Implementation + */ +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace detail { + +template +inline SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, + U SSL_connect_or_accept, V setup) { + SSL *ssl = nullptr; + { + std::lock_guard guard(ctx_mutex); + ssl = SSL_new(ctx); + } + + if (ssl) { + set_nonblocking(sock, true); + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + BIO_set_nbio(bio, 1); + SSL_set_bio(ssl, bio, bio); + + if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { + SSL_shutdown(ssl); + { + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); + } + set_nonblocking(sock, false); + return nullptr; + } + BIO_set_nbio(bio, 0); + set_nonblocking(sock, false); + } + + return ssl; +} + +inline void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, + bool shutdown_gracefully) { + // sometimes we may want to skip this to try to avoid SIGPIPE if we know + // the remote has closed the network connection + // Note that it is not always possible to avoid SIGPIPE, this is merely a + // best-efforts. + if (shutdown_gracefully) { SSL_shutdown(ssl); } + + std::lock_guard guard(ctx_mutex); + SSL_free(ssl); +} + +template +bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, + U ssl_connect_or_accept, + time_t timeout_sec, + time_t timeout_usec) { + auto res = 0; + while ((res = ssl_connect_or_accept(ssl)) != 1) { + auto err = SSL_get_error(ssl, res); + switch (err) { + case SSL_ERROR_WANT_READ: + if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + case SSL_ERROR_WANT_WRITE: + if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } + break; + default: break; + } + return false; + } + return true; +} + +template +inline bool process_server_socket_ssl( + const std::atomic &svr_sock, SSL *ssl, socket_t sock, + size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +inline bool +process_client_socket_ssl(SSL *ssl, socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm); +} + +class SSLInit { +public: + SSLInit() { + OPENSSL_init_ssl( + OPENSSL_INIT_LOAD_SSL_STRINGS | OPENSSL_INIT_LOAD_CRYPTO_STRINGS, NULL); + } +}; + +// SSL socket stream implementation +inline SSLSocketStream::SSLSocketStream(socket_t sock, SSL *ssl, + time_t read_timeout_sec, + time_t read_timeout_usec, + time_t write_timeout_sec, + time_t write_timeout_usec) + : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec) { + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); +} + +inline SSLSocketStream::~SSLSocketStream() = default; + +inline bool SSLSocketStream::is_readable() const { + return detail::select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; +} + +inline bool SSLSocketStream::is_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_); +} + +inline ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err == SSL_ERROR_WANT_READ || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_READ) { +#endif + if (SSL_pending(ssl_) > 0) { + return SSL_read(ssl_, ptr, static_cast(size)); + } else if (is_readable()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (is_writable()) { + auto handle_size = static_cast( + std::min(size, (std::numeric_limits::max)())); + + auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret < 0) { + auto err = SSL_get_error(ssl_, ret); + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || + (err == SSL_ERROR_SYSCALL && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { +#endif + if (is_writable()) { + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + ret = SSL_write(ssl_, ptr, static_cast(handle_size)); + if (ret >= 0) { return ret; } + err = SSL_get_error(ssl_, ret); + } else { + return -1; + } + } + } + return ret; + } + return -1; +} + +inline void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +inline void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +inline socket_t SSLSocketStream::socket() const { return sock_; } + +static SSLInit sslinit_; + +} // namespace detail + +// SSL HTTP server implementation +inline SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + + if (private_key_password != nullptr && (private_key_password[0] != '\0')) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, + reinterpret_cast(const_cast(private_key_password))); + } + + if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != + 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { + SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = SSL_CTX_new(TLS_server_method()); + + if (ctx_) { + SSL_CTX_set_options(ctx_, + SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + + SSL_CTX_set_min_proto_version(ctx_, TLS1_1_VERSION); + + if (SSL_CTX_use_certificate(ctx_, cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } else if (client_ca_cert_store) { + SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); + + SSL_CTX_set_verify( + ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); + } + } +} + +inline SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + ctx_ = SSL_CTX_new(TLS_method()); + if (ctx_) { + if (!setup_ssl_ctx_callback(*ctx_)) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLServer::~SSLServer() { + if (ctx_) { SSL_CTX_free(ctx_); } +} + +inline bool SSLServer::is_valid() const { return ctx_; } + +inline SSL_CTX *SSLServer::ssl_context() const { return ctx_; } + +inline bool SSLServer::process_and_close_socket(socket_t sock) { + auto ssl = detail::ssl_new( + sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + return detail::ssl_connect_or_accept_nonblocking( + sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_); + }, + [](SSL * /*ssl2*/) { return true; }); + + auto ret = false; + if (ssl) { + ret = detail::process_server_socket_ssl( + svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [this, ssl](Stream &strm, bool close_connection, + bool &connection_closed) { + return process_request(strm, close_connection, connection_closed, + [&](Request &req) { req.ssl = ssl; }); + }); + + // Shutdown gracefully if the result seemed successful, non-gracefully if + // the connection appeared to be closed. + const bool shutdown_gracefully = ret; + detail::ssl_delete(ctx_mutex_, ssl, shutdown_gracefully); + } + + detail::shutdown_socket(sock); + detail::close_socket(sock); + return ret; +} + +// SSL HTTP client implementation +inline SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +inline SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), + SSL_FILETYPE_PEM) != 1 || + SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), + SSL_FILETYPE_PEM) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + ctx_ = SSL_CTX_new(TLS_client_method()); + + detail::split(&host_[0], &host_[host_.size()], '.', + [&](const char *b, const char *e) { + host_components_.emplace_back(b, e); + }); + + if (client_cert != nullptr && client_key != nullptr) { + if (!private_key_password.empty()) { + SSL_CTX_set_default_passwd_cb_userdata( + ctx_, reinterpret_cast( + const_cast(private_key_password.c_str()))); + } + + if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || + SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { + SSL_CTX_free(ctx_); + ctx_ = nullptr; + } + } +} + +inline SSLClient::~SSLClient() { + if (ctx_) { SSL_CTX_free(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +inline bool SSLClient::is_valid() const { return ctx_; } + +inline void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + if (ctx_) { + if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { + // Free memory allocated for old cert and use new store `ca_cert_store` + SSL_CTX_set_cert_store(ctx_, ca_cert_store); + } + } else { + X509_STORE_free(ca_cert_store); + } + } +} + +inline void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); +} + +inline long SSLClient::get_openssl_verify_result() const { + return verify_result_; +} + +inline SSL_CTX *SSLClient::ssl_context() const { return ctx_; } + +inline bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + return is_valid() && ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in flight +inline bool SSLClient::connect_with_proxy(Socket &socket, Response &res, + bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = host_and_port_; + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = host_and_port_; + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +inline bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + if (!ca_cert_file_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), + nullptr)) { + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!SSL_CTX_load_verify_locations(ctx_, nullptr, + ca_cert_dir_path_.c_str())) { + ret = false; + } + } else { + auto loaded = false; +#ifdef _WIN32 + loaded = + detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); +#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && defined(__APPLE__) +#if TARGET_OS_OSX + loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); +#endif // TARGET_OS_OSX +#endif // _WIN32 + if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } + } + }); + + return ret; +} + +inline bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + auto ssl = detail::ssl_new( + socket.sock, ctx_, ctx_mutex_, + [&](SSL *ssl2) { + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + return false; + } + SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); + } + + if (!detail::ssl_connect_or_accept_nonblocking( + socket.sock, ssl2, SSL_connect, connection_timeout_sec_, + connection_timeout_usec_)) { + error = Error::SSLConnection; + return false; + } + + if (server_certificate_verification_) { + verify_result_ = SSL_get_verify_result(ssl2); + + if (verify_result_ != X509_V_OK) { + error = Error::SSLServerVerification; + return false; + } + + auto server_cert = SSL_get1_peer_certificate(ssl2); + + if (server_cert == nullptr) { + error = Error::SSLServerVerification; + return false; + } + + if (!verify_host(server_cert)) { + X509_free(server_cert); + error = Error::SSLServerVerification; + return false; + } + X509_free(server_cert); + } + + return true; + }, + [&](SSL *ssl2) { + // NOTE: Direct call instead of using the OpenSSL macro to suppress + // -Wold-style-cast warning + // SSL_set_tlsext_host_name(ssl2, host_.c_str()); + SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(host_.c_str()))); + return true; + }); + + if (ssl) { + socket.ssl = ssl; + return true; + } + + shutdown_socket(socket); + close_socket(socket); + return false; +} + +inline void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +inline void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + detail::ssl_delete(ctx_mutex_, socket.ssl, shutdown_gracefully); + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +inline bool +SSLClient::process_socket(const Socket &socket, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, std::move(callback)); +} + +inline bool SSLClient::is_ssl() const { return true; } + +inline bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +inline bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 {}; + struct in_addr addr {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (val->type == type) { + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || + !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +inline bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return check_host_name(name, static_cast(name_len)); + } + } + + return false; +} + +inline bool SSLClient::check_host_name(const char *pattern, + size_t pattern_len) const { + if (host_.size() == pattern_len && host_ == pattern) { return true; } + + // Wildcard match + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + std::vector pattern_components; + detail::split(&pattern[0], &pattern[pattern_len], '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + + if (host_components_.size() != pattern_components.size()) { return false; } + + auto itr = pattern_components.begin(); + for (const auto &h : host_components_) { + auto &p = *itr; + if (p != h && p != "*") { + auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && + !p.compare(0, p.size() - 1, h)); + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} +#endif + +// Universal client implementation +inline Client::Client(const std::string &scheme_host_port) + : Client(scheme_host_port, std::string(), std::string()) {} + +inline Client::Client(const std::string &scheme_host_port, + const std::string &client_cert_path, + const std::string &client_key_path) { + const static std::regex re( + R"((?:([a-z]+):\/\/)?(?:\[([\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + + std::smatch m; + if (std::regex_match(scheme_host_port, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (!scheme.empty() && (scheme != "http" && scheme != "https")) { +#else + if (!scheme.empty() && scheme != "http") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "https"; + + auto host = m[2].str(); + if (host.empty()) { host = m[3].str(); } + + auto port_str = m[4].str(); + auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + if (is_ssl) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + is_ssl_ = is_ssl; +#endif + } else { + cli_ = detail::make_unique(host, port, client_cert_path, + client_key_path); + } + } else { + cli_ = detail::make_unique(scheme_host_port, 80, + client_cert_path, client_key_path); + } +} + +inline Client::Client(const std::string &host, int port) + : cli_(detail::make_unique(host, port)) {} + +inline Client::Client(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path) + : cli_(detail::make_unique(host, port, client_cert_path, + client_key_path)) {} + +inline Client::~Client() = default; + +inline bool Client::is_valid() const { + return cli_ != nullptr && cli_->is_valid(); +} + +inline Result Client::Get(const std::string &path) { return cli_->Get(path); } +inline Result Client::Get(const std::string &path, const Headers &headers) { + return cli_->Get(path, headers); +} +inline Result Client::Get(const std::string &path, Progress progress) { + return cli_->Get(path, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + Progress progress) { + return cli_->Get(path, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver) { + return cli_->Get(path, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver)); +} +inline Result Client::Get(const std::string &path, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, Progress progress) { + return cli_->Get(path, params, headers, std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(content_receiver), + std::move(progress)); +} +inline Result Client::Get(const std::string &path, const Params ¶ms, + const Headers &headers, + ResponseHandler response_handler, + ContentReceiver content_receiver, Progress progress) { + return cli_->Get(path, params, headers, std::move(response_handler), + std::move(content_receiver), std::move(progress)); +} + +inline Result Client::Head(const std::string &path) { return cli_->Head(path); } +inline Result Client::Head(const std::string &path, const Headers &headers) { + return cli_->Head(path, headers); +} + +inline Result Client::Post(const std::string &path) { return cli_->Post(path); } +inline Result Client::Post(const std::string &path, const Headers &headers) { + return cli_->Post(path, headers); +} +inline Result Client::Post(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Post(path, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_length, content_type); +} +inline Result Client::Post(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Post(path, body, content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Post(path, headers, body, content_type); +} +inline Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Post(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Post(const std::string &path, const Params ¶ms) { + return cli_->Post(path, params); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Post(path, headers, params); +} +inline Result Client::Post(const std::string &path, + const MultipartFormDataItems &items) { + return cli_->Post(path, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + return cli_->Post(path, headers, items); +} +inline Result Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Post(path, headers, items, boundary); +} +inline Result +Client::Post(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Post(path, headers, items, provider_items); +} +inline Result Client::Put(const std::string &path) { return cli_->Put(path); } +inline Result Client::Put(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Put(path, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_length, content_type); +} +inline Result Client::Put(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Put(path, body, content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Put(path, headers, body, content_type); +} +inline Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Put(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Put(const std::string &path, const Params ¶ms) { + return cli_->Put(path, params); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const Params ¶ms) { + return cli_->Put(path, headers, params); +} +inline Result Client::Put(const std::string &path, + const MultipartFormDataItems &items) { + return cli_->Put(path, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items) { + return cli_->Put(path, headers, items); +} +inline Result Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const std::string &boundary) { + return cli_->Put(path, headers, items, boundary); +} +inline Result +Client::Put(const std::string &path, const Headers &headers, + const MultipartFormDataItems &items, + const MultipartFormDataProviderItems &provider_items) { + return cli_->Put(path, headers, items, provider_items); +} +inline Result Client::Patch(const std::string &path) { + return cli_->Patch(path); +} +inline Result Client::Patch(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_length, content_type); +} +inline Result Client::Patch(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, body, content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Patch(path, headers, body, content_type); +} +inline Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, std::move(content_provider), content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type); +} +inline Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type) { + return cli_->Patch(path, headers, std::move(content_provider), content_type); +} +inline Result Client::Delete(const std::string &path) { + return cli_->Delete(path); +} +inline Result Client::Delete(const std::string &path, const Headers &headers) { + return cli_->Delete(path, headers); +} +inline Result Client::Delete(const std::string &path, const char *body, + size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const char *body, size_t content_length, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_length, content_type); +} +inline Result Client::Delete(const std::string &path, const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, body, content_type); +} +inline Result Client::Delete(const std::string &path, const Headers &headers, + const std::string &body, + const std::string &content_type) { + return cli_->Delete(path, headers, body, content_type); +} +inline Result Client::Options(const std::string &path) { + return cli_->Options(path); +} +inline Result Client::Options(const std::string &path, const Headers &headers) { + return cli_->Options(path, headers); +} + +inline bool Client::send(Request &req, Response &res, Error &error) { + return cli_->send(req, res, error); +} + +inline Result Client::send(const Request &req) { return cli_->send(req); } + +inline void Client::stop() { cli_->stop(); } + +inline std::string Client::host() const { return cli_->host(); } + +inline int Client::port() const { return cli_->port(); } + +inline size_t Client::is_socket_open() const { return cli_->is_socket_open(); } + +inline socket_t Client::socket() const { return cli_->socket(); } + +inline void +Client::set_hostname_addr_map(std::map addr_map) { + cli_->set_hostname_addr_map(std::move(addr_map)); +} + +inline void Client::set_default_headers(Headers headers) { + cli_->set_default_headers(std::move(headers)); +} + +inline void Client::set_header_writer( + std::function const &writer) { + cli_->set_header_writer(writer); +} + +inline void Client::set_address_family(int family) { + cli_->set_address_family(family); +} + +inline void Client::set_tcp_nodelay(bool on) { cli_->set_tcp_nodelay(on); } + +inline void Client::set_socket_options(SocketOptions socket_options) { + cli_->set_socket_options(std::move(socket_options)); +} + +inline void Client::set_connection_timeout(time_t sec, time_t usec) { + cli_->set_connection_timeout(sec, usec); +} + +inline void Client::set_read_timeout(time_t sec, time_t usec) { + cli_->set_read_timeout(sec, usec); +} + +inline void Client::set_write_timeout(time_t sec, time_t usec) { + cli_->set_write_timeout(sec, usec); +} + +inline void Client::set_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_basic_auth(username, password); +} +inline void Client::set_bearer_token_auth(const std::string &token) { + cli_->set_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} +#endif + +inline void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } +inline void Client::set_follow_location(bool on) { + cli_->set_follow_location(on); +} + +inline void Client::set_url_encode(bool on) { cli_->set_url_encode(on); } + +inline void Client::set_compress(bool on) { cli_->set_compress(on); } + +inline void Client::set_decompress(bool on) { cli_->set_decompress(on); } + +inline void Client::set_interface(const std::string &intf) { + cli_->set_interface(intf); +} + +inline void Client::set_proxy(const std::string &host, int port) { + cli_->set_proxy(host, port); +} +inline void Client::set_proxy_basic_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_basic_auth(username, password); +} +inline void Client::set_proxy_bearer_token_auth(const std::string &token) { + cli_->set_proxy_bearer_token_auth(token); +} +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} +#endif + +inline void Client::set_logger(Logger logger) { + cli_->set_logger(std::move(logger)); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +inline void Client::set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path) { + cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); +} + +inline void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (is_ssl_) { + static_cast(*cli_).set_ca_cert_store(ca_cert_store); + } else { + cli_->set_ca_cert_store(ca_cert_store); + } +} + +inline void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { + set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); +} + +inline long Client::get_openssl_verify_result() const { + if (is_ssl_) { + return static_cast(*cli_).get_openssl_verify_result(); + } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} + +inline SSL_CTX *Client::ssl_context() const { + if (is_ssl_) { return static_cast(*cli_).ssl_context(); } + return nullptr; +} +#endif + +// ---------------------------------------------------------------------------- + +} // namespace httplib + +#if defined(_WIN32) && defined(CPPHTTPLIB_USE_POLL) +#undef poll +#endif + +#endif // CPPHTTPLIB_HTTPLIB_H diff --git a/thirdparty/stb_image_write.h b/thirdparty/stb_image_write.h index 5589a7ec2..55118853e 100644 --- a/thirdparty/stb_image_write.h +++ b/thirdparty/stb_image_write.h @@ -177,7 +177,7 @@ STBIWDEF int stbi_write_png(char const *filename, int w, int h, int comp, const STBIWDEF int stbi_write_bmp(char const *filename, int w, int h, int comp, const void *data); STBIWDEF int stbi_write_tga(char const *filename, int w, int h, int comp, const void *data); STBIWDEF int stbi_write_hdr(char const *filename, int w, int h, int comp, const float *data); -STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality); +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality, const char* parameters = NULL); #ifdef STBIW_WINDOWS_UTF8 STBIWDEF int stbiw_convert_wchar_to_utf8(char *buffer, size_t bufferlen, const wchar_t* input); @@ -1412,7 +1412,7 @@ static int stbiw__jpg_processDU(stbi__write_context *s, int *bitBuf, int *bitCnt return DU[0]; } -static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality) { +static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, int comp, const void* data, int quality, const char* parameters) { // Constants that don't pollute global namespace static const unsigned char std_dc_luminance_nrcodes[] = {0,0,1,5,1,1,1,1,1,1,0,0,0,0,0,0,0}; static const unsigned char std_dc_luminance_values[] = {0,1,2,3,4,5,6,7,8,9,10,11}; @@ -1521,6 +1521,20 @@ static int stbi_write_jpg_core(stbi__write_context *s, int width, int height, in s->func(s->context, (void*)YTable, sizeof(YTable)); stbiw__putc(s, 1); s->func(s->context, UVTable, sizeof(UVTable)); + + // comment block with parameters of generation + if(parameters != NULL) { + stbiw__putc(s, 0xFF /* comnent */ ); + stbiw__putc(s, 0xFE /* marker */ ); + size_t param_length = std::min(2 + strlen("parameters") + 1 + strlen(parameters) + 1, (size_t) 0xFFFF); + stbiw__putc(s, param_length >> 8); // no need to mask, length < 65536 + stbiw__putc(s, param_length & 0xFF); + s->func(s->context, (void*)"parameters", strlen("parameters") + 1); // std::string is zero-terminated + s->func(s->context, (void*)parameters, std::min(param_length, (size_t) 65534) - 2 - strlen("parameters") - 1); + if(param_length > 65534) stbiw__putc(s, 0); // always zero-terminate for safety + if(param_length & 1) stbiw__putc(s, 0xFF); // pad to even length + } + s->func(s->context, (void*)head1, sizeof(head1)); s->func(s->context, (void*)(std_dc_luminance_nrcodes+1), sizeof(std_dc_luminance_nrcodes)-1); s->func(s->context, (void*)std_dc_luminance_values, sizeof(std_dc_luminance_values)); @@ -1625,16 +1639,16 @@ STBIWDEF int stbi_write_jpg_to_func(stbi_write_func *func, void *context, int x, { stbi__write_context s = { 0 }; stbi__start_write_callbacks(&s, func, context); - return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality); + return stbi_write_jpg_core(&s, x, y, comp, (void *) data, quality, NULL); } #ifndef STBI_WRITE_NO_STDIO -STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality) +STBIWDEF int stbi_write_jpg(char const *filename, int x, int y, int comp, const void *data, int quality, const char* parameters) { stbi__write_context s = { 0 }; if (stbi__start_write_file(&s,filename)) { - int r = stbi_write_jpg_core(&s, x, y, comp, data, quality); + int r = stbi_write_jpg_core(&s, x, y, comp, data, quality, parameters); stbi__end_write_file(&s); return r; } else diff --git a/unet.hpp b/unet.hpp index 94a8ba46a..9193dcd67 100644 --- a/unet.hpp +++ b/unet.hpp @@ -166,6 +166,7 @@ class SpatialVideoTransformer : public SpatialTransformer { // ldm.modules.diffusionmodules.openaimodel.UNetModel class UnetModelBlock : public GGMLBlock { protected: + static std::map empty_tensor_types; SDVersion version = VERSION_SD1; // network hparams int in_channels = 4; @@ -183,13 +184,13 @@ class UnetModelBlock : public GGMLBlock { int model_channels = 320; int adm_in_channels = 2816; // only for VERSION_SDXL/SVD - UnetModelBlock(SDVersion version = VERSION_SD1) + UnetModelBlock(SDVersion version = VERSION_SD1, std::map& tensor_types = empty_tensor_types, bool flash_attn = false) : version(version) { - if (version == VERSION_SD2) { + if (sd_version_is_sd2(version)) { context_dim = 1024; num_head_channels = 64; num_heads = -1; - } else if (version == VERSION_SDXL) { + } else if (sd_version_is_sdxl(version)) { context_dim = 2048; attention_resolutions = {4, 2}; channel_mult = {1, 2, 4}; @@ -204,6 +205,12 @@ class UnetModelBlock : public GGMLBlock { num_head_channels = 64; num_heads = -1; } + if (sd_version_is_inpaint(version)) { + in_channels = 9; + } else if (sd_version_is_unet_edit(version)) { + in_channels = 8; + } + // dims is always 2 // use_temporal_attention is always True for SVD @@ -211,7 +218,7 @@ class UnetModelBlock : public GGMLBlock { // time_embed_1 is nn.SiLU() blocks["time_embed.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); - if (version == VERSION_SDXL || version == VERSION_SVD) { + if (sd_version_is_sdxl(version) || version == VERSION_SVD) { blocks["label_emb.0.0"] = std::shared_ptr(new Linear(adm_in_channels, time_embed_dim)); // label_emb_1 is nn.SiLU() blocks["label_emb.0.2"] = std::shared_ptr(new Linear(time_embed_dim, time_embed_dim)); @@ -242,7 +249,7 @@ class UnetModelBlock : public GGMLBlock { if (version == VERSION_SVD) { return new SpatialVideoTransformer(in_channels, n_head, d_head, depth, context_dim); } else { - return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim); + return new SpatialTransformer(in_channels, n_head, d_head, depth, context_dim, flash_attn); } }; @@ -532,10 +539,12 @@ struct UNetModelRunner : public GGMLRunner { UnetModelBlock unet; UNetModelRunner(ggml_backend_t backend, - ggml_type wtype, - SDVersion version = VERSION_SD1) - : GGMLRunner(backend, wtype), unet(version) { - unet.init(params_ctx, wtype); + std::map& tensor_types, + const std::string prefix, + SDVersion version = VERSION_SD1, + bool flash_attn = false) + : GGMLRunner(backend), unet(version, tensor_types, flash_attn) { + unet.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -564,6 +573,7 @@ struct UNetModelRunner : public GGMLRunner { context = to_backend(context); y = to_backend(y); timesteps = to_backend(timesteps); + c_concat = to_backend(c_concat); for (int i = 0; i < controls.size(); i++) { controls[i] = to_backend(controls[i]); @@ -649,4 +659,4 @@ struct UNetModelRunner : public GGMLRunner { } }; -#endif // __UNET_HPP__ \ No newline at end of file +#endif // __UNET_HPP__ diff --git a/upscaler.cpp b/upscaler.cpp index 096352993..137213496 100644 --- a/upscaler.cpp +++ b/upscaler.cpp @@ -15,30 +15,38 @@ struct UpscalerGGML { } bool load_from_file(const std::string& esrgan_path) { -#ifdef SD_USE_CUBLAS +#ifdef SD_USE_CUDA LOG_DEBUG("Using CUDA backend"); backend = ggml_backend_cuda_init(0); #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); - ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr); + ggml_log_set(ggml_log_callback_default, nullptr); backend = ggml_backend_metal_init(); #endif #ifdef SD_USE_VULKAN LOG_DEBUG("Using Vulkan backend"); backend = ggml_backend_vk_init(0); #endif +#ifdef SD_USE_OPENCL + LOG_DEBUG("Using OpenCL backend"); + backend = ggml_backend_opencl_init(); +#endif #ifdef SD_USE_SYCL LOG_DEBUG("Using SYCL backend"); backend = ggml_backend_sycl_init(0); #endif - + ModelLoader model_loader; + if (!model_loader.init_from_file(esrgan_path)) { + LOG_ERROR("init model loader from file failed: '%s'", esrgan_path.c_str()); + } + model_loader.set_wtype_override(model_data_type); if (!backend) { LOG_DEBUG("Using CPU backend"); backend = ggml_backend_cpu_init(); } LOG_INFO("Upscaler weight type: %s", ggml_type_name(model_data_type)); - esrgan_upscaler = std::make_shared(backend, model_data_type); + esrgan_upscaler = std::make_shared(backend, model_loader.tensor_storages_types); if (!esrgan_upscaler->load_from_file(esrgan_path)) { return false; } @@ -96,8 +104,7 @@ struct upscaler_ctx_t { }; upscaler_ctx_t* new_upscaler_ctx(const char* esrgan_path_c_str, - int n_threads, - enum sd_type_t wtype) { + int n_threads) { upscaler_ctx_t* upscaler_ctx = (upscaler_ctx_t*)malloc(sizeof(upscaler_ctx_t)); if (upscaler_ctx == NULL) { return NULL; diff --git a/util.cpp b/util.cpp index 5de5ce26e..92bc9ef50 100644 --- a/util.cpp +++ b/util.cpp @@ -22,6 +22,7 @@ #include #endif +#include "ggml-cpu.h" #include "ggml.h" #include "stable-diffusion.h" @@ -111,19 +112,32 @@ std::vector get_files_from_dir(const std::string& dir) { sprintf(directoryPath, "%s\\%s\\*", currentDirectory, dir.c_str()); // Find the first file in the directory - hFind = FindFirstFile(directoryPath, &findFileData); - + hFind = FindFirstFile(directoryPath, &findFileData); + bool isAbsolutePath = false; // Check if the directory was found if (hFind == INVALID_HANDLE_VALUE) { - printf("Unable to find directory.\n"); - return files; + printf("Unable to find directory. Try with original path \n"); + + char directoryPathAbsolute[MAX_PATH]; + sprintf(directoryPathAbsolute, "%s*", dir.c_str()); + + hFind = FindFirstFile(directoryPathAbsolute, &findFileData); + isAbsolutePath = true; + if (hFind == INVALID_HANDLE_VALUE) { + printf("Absolute path was also wrong.\n"); + return files; + } } // Loop through all files in the directory do { // Check if the found file is a regular file (not a directory) if (!(findFileData.dwFileAttributes & FILE_ATTRIBUTE_DIRECTORY)) { - files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName)); + if (isAbsolutePath) { + files.push_back(dir + "\\" + std::string(findFileData.cFileName)); + } else { + files.push_back(std::string(currentDirectory) + "\\" + dir + "\\" + std::string(findFileData.cFileName)); + } } } while (FindNextFile(hFind, &findFileData) != 0); @@ -276,6 +290,23 @@ std::string path_join(const std::string& p1, const std::string& p2) { return p1 + "/" + p2; } +std::vector splitString(const std::string& str, char delimiter) { + std::vector result; + size_t start = 0; + size_t end = str.find(delimiter); + + while (end != std::string::npos) { + result.push_back(str.substr(start, end - start)); + start = end + 1; + end = str.find(delimiter, start); + } + + // Add the last segment after the last delimiter + result.push_back(str.substr(start)); + + return result; +} + sd_image_t* preprocess_id_image(sd_image_t* img) { int shortest_edge = 224; int size = shortest_edge; @@ -330,7 +361,7 @@ void pretty_progress(int step, int steps, float time) { } } progress += "|"; - printf(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s", + printf(time > 1.0f ? "\r%s %i/%i - %.2fs/it" : "\r%s %i/%i - %.2fit/s\033[K", progress.c_str(), step, steps, time > 1.0f || time == 0 ? time : (1.0f / time)); fflush(stdout); // for linux @@ -393,7 +424,6 @@ const char* sd_get_system_info() { static char buffer[1024]; std::stringstream ss; ss << "System Info: \n"; - ss << " BLAS = " << ggml_cpu_has_blas() << std::endl; ss << " SSE3 = " << ggml_cpu_has_sse3() << std::endl; ss << " AVX = " << ggml_cpu_has_avx() << std::endl; ss << " AVX2 = " << ggml_cpu_has_avx2() << std::endl; @@ -411,10 +441,6 @@ const char* sd_get_system_info() { return buffer; } -const char* sd_type_name(enum sd_type_t type) { - return ggml_type_name((ggml_type)type); -} - sd_image_f32_t sd_image_t_to_sd_image_f32_t(sd_image_t image) { sd_image_f32_t converted_image; converted_image.width = image.width; diff --git a/util.h b/util.h index 9b1e6734f..d98c9a280 100644 --- a/util.h +++ b/util.h @@ -7,6 +7,9 @@ #include "stable-diffusion.h" +#define SAFE_STR(s) ((s) ? (s) : "") +#define BOOL_STR(b) ((b) ? "true" : "false") + bool ends_with(const std::string& str, const std::string& ending); bool starts_with(const std::string& str, const std::string& start); bool contains(const std::string& str, const std::string& substr); @@ -45,7 +48,7 @@ sd_image_f32_t resize_sd_image_f32_t(sd_image_f32_t image, int target_width, int sd_image_f32_t clip_preprocess(sd_image_f32_t image, int size); std::string path_join(const std::string& p1, const std::string& p2); - +std::vector splitString(const std::string& str, char delimiter); void pretty_progress(int step, int steps, float time); void log_printf(sd_log_level_t level, const char* file, int line, const char* format, ...); diff --git a/vae.hpp b/vae.hpp index 42b694cd5..4add881f6 100644 --- a/vae.hpp +++ b/vae.hpp @@ -163,8 +163,9 @@ class AE3DConv : public Conv2d { class VideoResnetBlock : public ResnetBlock { protected: - void init_params(struct ggml_context* ctx, ggml_type wtype) { - params["mix_factor"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1); + void init_params(struct ggml_context* ctx, std::map& tensor_types, const std::string prefix = "") { + enum ggml_type wtype = (tensor_types.find(prefix + "mix_factor") != tensor_types.end()) ? tensor_types[prefix + "mix_factor"] : GGML_TYPE_F32; + params["mix_factor"] = ggml_new_tensor_1d(ctx, wtype, 1); } float get_alpha() { @@ -457,7 +458,7 @@ class AutoencodingEngine : public GGMLBlock { bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (sd_version_is_dit(version)) { dd_config.z_channels = 16; use_quant = false; } @@ -524,12 +525,13 @@ struct AutoEncoderKL : public GGMLRunner { AutoencodingEngine ae; AutoEncoderKL(ggml_backend_t backend, - ggml_type wtype, + std::map& tensor_types, + const std::string prefix, bool decode_only = false, bool use_video_decoder = false, SDVersion version = VERSION_SD1) - : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend, wtype) { - ae.init(params_ctx, wtype); + : decode_only(decode_only), ae(decode_only, use_video_decoder, version), GGMLRunner(backend) { + ae.init(params_ctx, tensor_types, prefix); } std::string get_desc() { @@ -612,4 +614,4 @@ struct AutoEncoderKL : public GGMLRunner { }; }; -#endif \ No newline at end of file +#endif