Skip to content

Commit 85ce9b5

Browse files
authored
feat: add sumcheck support for grumpkin curve (PROOF-913) (#249)
* sumcheck grumpkin support * rework grumpkin field * add field * work on sumcheck grumpkin support * work on grumpkin support * grumpkin support * test grumpkin * support grumpkin
1 parent adb4992 commit 85ce9b5

File tree

24 files changed

+288
-22
lines changed

24 files changed

+288
-22
lines changed

cbindings/blitzar_api.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ extern "C" {
3131
#define SXT_CURVE_GRUMPKIN 3
3232

3333
#define SXT_FIELD_SCALAR255 0
34+
#define SXT_FIELD_GRUMPKIN 1
3435

3536
/** config struct to hold the chosen backend */
3637
struct sxt_config {

sxt/cbindings/backend/cpu_backend.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -76,28 +76,27 @@ void cpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi
7676
auto num_variables = static_cast<size_t>(std::max(basn::ceil_log2(descriptor.n), 1));
7777
cbnb::switch_field_type(
7878
static_cast<cbnb::field_id_t>(field_id), [&]<class T>(std::type_identity<T>) noexcept {
79-
static_assert(std::same_as<T, s25t::element>, "only support curve-255 right now");
8079
// transcript
8180
callback_sumcheck_transcript<T> transcript{
8281
reinterpret_cast<callback_sumcheck_transcript<T>::callback_t>(
8382
const_cast<void*>(transcript_callback)),
8483
transcript_context};
8584

8685
// prove
87-
basct::span<s25t::element> polynomials_span{
88-
static_cast<s25t::element*>(polynomials),
86+
basct::span<T> polynomials_span{
87+
static_cast<T*>(polynomials),
8988
(descriptor.round_degree + 1u) * num_variables,
9089
};
91-
basct::span<s25t::element> evaluation_point_span{
92-
static_cast<s25t::element*>(evaluation_point),
90+
basct::span<T> evaluation_point_span{
91+
static_cast<T*>(evaluation_point),
9392
num_variables,
9493
};
95-
basct::cspan<s25t::element> mles_span{
96-
static_cast<const s25t::element*>(descriptor.mles),
94+
basct::cspan<T> mles_span{
95+
static_cast<const T*>(descriptor.mles),
9796
descriptor.n * descriptor.num_mles,
9897
};
99-
basct::cspan<std::pair<s25t::element, unsigned>> product_table_span{
100-
static_cast<const std::pair<s25t::element, unsigned>*>(descriptor.product_table),
98+
basct::cspan<std::pair<T, unsigned>> product_table_span{
99+
static_cast<const std::pair<T, unsigned>*>(descriptor.product_table),
101100
descriptor.num_products,
102101
};
103102
basct::cspan<unsigned> product_terms_span{

sxt/cbindings/backend/gpu_backend.cc

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -109,28 +109,27 @@ void gpu_backend::prove_sumcheck(void* polynomials, void* evaluation_point, unsi
109109
auto num_variables = static_cast<size_t>(std::max(basn::ceil_log2(descriptor.n), 1));
110110
cbnb::switch_field_type(
111111
static_cast<cbnb::field_id_t>(field_id), [&]<class T>(std::type_identity<T>) noexcept {
112-
static_assert(std::same_as<T, s25t::element>, "only support curve-255 right now");
113112
// transcript
114113
callback_sumcheck_transcript<T> transcript{
115114
reinterpret_cast<callback_sumcheck_transcript<T>::callback_t>(
116115
const_cast<void*>(transcript_callback)),
117116
transcript_context};
118117

119118
// prove
120-
basct::span<s25t::element> polynomials_span{
121-
static_cast<s25t::element*>(polynomials),
119+
basct::span<T> polynomials_span{
120+
static_cast<T*>(polynomials),
122121
(descriptor.round_degree + 1u) * num_variables,
123122
};
124-
basct::span<s25t::element> evaluation_point_span{
125-
static_cast<s25t::element*>(evaluation_point),
123+
basct::span<T> evaluation_point_span{
124+
static_cast<T*>(evaluation_point),
126125
num_variables,
127126
};
128-
basct::cspan<s25t::element> mles_span{
129-
static_cast<const s25t::element*>(descriptor.mles),
127+
basct::cspan<T> mles_span{
128+
static_cast<const T*>(descriptor.mles),
130129
descriptor.n * descriptor.num_mles,
131130
};
132-
basct::cspan<std::pair<s25t::element, unsigned>> product_table_span{
133-
static_cast<const std::pair<s25t::element, unsigned>*>(descriptor.product_table),
131+
basct::cspan<std::pair<T, unsigned>> product_table_span{
132+
static_cast<const std::pair<T, unsigned>*>(descriptor.product_table),
134133
descriptor.num_products,
135134
};
136135
basct::cspan<unsigned> product_terms_span{

sxt/cbindings/base/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ sxt_cc_component(
5757
deps = [
5858
":field_id",
5959
"//sxt/base/error:panic",
60+
"//sxt/fieldgk/realization:field",
6061
"//sxt/scalar25/realization:field",
6162
],
6263
)

sxt/cbindings/base/field_id.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,6 @@ namespace sxt::cbnb {
2727
*/
2828
enum class field_id_t : unsigned {
2929
scalar25519 = 0,
30+
grumpkin = 1,
3031
};
3132
} // namespace sxt::cbnb

sxt/cbindings/base/field_id_utility.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
#include "sxt/base/error/panic.h"
2222
#include "sxt/cbindings/base/field_id.h"
23+
#include "sxt/fieldgk/realization/field.h"
2324
#include "sxt/scalar25/realization/field.h"
2425

2526
namespace sxt::cbnb {
@@ -31,6 +32,9 @@ template <class F> void switch_field_type(field_id_t id, F f) {
3132
case field_id_t::scalar25519:
3233
f(std::type_identity<s25t::element>{});
3334
break;
35+
case field_id_t::grumpkin:
36+
f(std::type_identity<fgkt::element>{});
37+
break;
3438
default:
3539
baser::panic("unsupported field id {}", static_cast<unsigned>(id));
3640
}

sxt/fieldgk/operation/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,16 @@ sxt_cc_component(
8383
],
8484
)
8585

86+
sxt_cc_component(
87+
name = "muladd",
88+
with_test = False,
89+
deps = [
90+
":add",
91+
":mul",
92+
"//sxt/base/macro:cuda_callable",
93+
],
94+
)
95+
8696
sxt_cc_component(
8797
name = "neg",
8898
test_deps = [

sxt/fieldgk/operation/muladd.cc

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
2+
*
3+
* Copyright 2025-present Space and Time Labs, Inc.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#include "sxt/fieldgk/operation/muladd.h"

sxt/fieldgk/operation/muladd.h

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/** Proofs GPU - Space and Time's cryptographic proof algorithms on the CPU and GPU.
2+
*
3+
* Copyright 2025-present Space and Time Labs, Inc.
4+
*
5+
* Licensed under the Apache License, Version 2.0 (the "License");
6+
* you may not use this file except in compliance with the License.
7+
* You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
#pragma once
18+
19+
#include "sxt/base/macro/cuda_callable.h"
20+
#include "sxt/fieldgk/operation/add.h"
21+
#include "sxt/fieldgk/operation/mul.h"
22+
23+
namespace sxt::fgko {
24+
//--------------------------------------------------------------------------------------------------
25+
// muladd
26+
//--------------------------------------------------------------------------------------------------
27+
inline CUDA_CALLABLE void muladd(fgkt::element& s, const fgkt::element& a, const fgkt::element& b,
28+
const fgkt::element& c) noexcept {
29+
auto cp = c;
30+
mul(s, a, b);
31+
add(s, s, cp);
32+
}
33+
} // namespace sxt::fgko

sxt/fieldgk/realization/BUILD

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
load(
2+
"//bazel:sxt_build_system.bzl",
3+
"sxt_cc_component",
4+
)
5+
6+
sxt_cc_component(
7+
name = "field",
8+
with_test = False,
9+
deps = [
10+
"//sxt/base/field:element",
11+
"//sxt/fieldgk/operation:add",
12+
"//sxt/fieldgk/operation:mul",
13+
"//sxt/fieldgk/operation:muladd",
14+
"//sxt/fieldgk/operation:neg",
15+
"//sxt/fieldgk/operation:sub",
16+
"//sxt/fieldgk/type:element",
17+
],
18+
)

0 commit comments

Comments
 (0)