@@ -1213,6 +1213,78 @@ struct test_get_rows_back : public test_case {
12131213 }
12141214};
12151215
1216+ // GGML_OP_SET_ROWS
1217+ struct test_set_rows : public test_case {
1218+ const ggml_type type;
1219+ const int n; // cols
1220+ const int m; // rows
1221+ const int r; // rows to set
1222+ const int b0; // batch size
1223+ const int b1; // batch size
1224+ const int bs; // batch size src (for testing broadcast)
1225+ const bool v; // view (non-contiguous src1)
1226+
1227+ std::string vars () override {
1228+ return VARS_TO_STR7 (type, n, m, r, b0, bs, v);
1229+ }
1230+
1231+ test_set_rows (ggml_type type = GGML_TYPE_F32, int n = 10 , int m = 5 , int r = 3 , int b = 1 , int bs = 1 , bool v = false )
1232+ : type(type), n(n), m(m), r(r), b0(b), b1(3 ), bs(bs), v(v) {
1233+ GGML_ASSERT (b0 % bs == 0 && " b0 must be a multiple of bs" );
1234+ GGML_ASSERT (r <= m && " r must be less than or equal to m" );
1235+ }
1236+
1237+ ggml_tensor * build_graph (ggml_context * ctx) override {
1238+ ggml_tensor * dst = ggml_new_tensor_4d (ctx, type, n, m, b0, b1);
1239+ ggml_set_name (dst, " dst" );
1240+
1241+ ggml_tensor * src = ggml_new_tensor_4d (ctx, GGML_TYPE_F32, n, r, b0, b1);
1242+ ggml_set_name (src, " src" );
1243+
1244+ ggml_tensor * row_idxs = ggml_new_tensor_3d (ctx, GGML_TYPE_I64, r, bs, b1);
1245+ ggml_set_name (row_idxs, " row_idxs" );
1246+
1247+ if (v) {
1248+ src = ggml_view_4d (ctx, src, n, r/2 , b0, b1, src->nb [1 ], src->nb [2 ], src->nb [3 ], 0 );
1249+ row_idxs = ggml_view_3d (ctx, row_idxs, r/2 , bs, b1, row_idxs->nb [1 ], row_idxs->nb [2 ], 0 );
1250+ ggml_set_name (row_idxs, " view_of_rows" );
1251+ }
1252+
1253+ ggml_tensor * out = ggml_set_rows (ctx, dst, src, row_idxs);
1254+ ggml_set_name (out, " out" );
1255+
1256+ return out;
1257+ }
1258+
1259+ void initialize_tensors (ggml_context * ctx) override {
1260+ std::random_device rd;
1261+ std::default_random_engine rng (rd ());
1262+ for (ggml_tensor * t = ggml_get_first_tensor (ctx); t != NULL ; t = ggml_get_next_tensor (ctx, t)) {
1263+ if (t->type == GGML_TYPE_I64) {
1264+ if (ggml_is_view_op (t->op )) {
1265+ continue ;
1266+ }
1267+
1268+ for (int i2 = 0 ; i2 < t->ne [2 ]; i2++) {
1269+ for (int i1 = 0 ; i1 < t->ne [1 ]; i1++) {
1270+ std::vector<int64_t > data (m);
1271+ for (int i = 0 ; i < m; i++) {
1272+ data[i] = i;
1273+ }
1274+ std::shuffle (data.begin (), data.end (), rng);
1275+ data.resize (t->ne [0 ]);
1276+
1277+ const size_t offs = i1*t->nb [1 ] + i2*t->nb [2 ];
1278+ ggml_backend_tensor_set (t, data.data (), offs, t->ne [0 ]*sizeof (int64_t ));
1279+ }
1280+ }
1281+ } else {
1282+ init_tensor_uniform (t);
1283+ }
1284+ }
1285+ }
1286+ };
1287+
12161288// GGML_OP_ARGMAX
12171289struct test_argmax : public test_case {
12181290 const ggml_type type;
@@ -3984,6 +4056,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
39844056 test_cases.emplace_back (new test_get_rows_back (GGML_TYPE_I32, 256 , 5 , 4 , 1 , v));
39854057 }
39864058
4059+ test_cases.emplace_back (new test_set_rows (GGML_TYPE_F32, 1 , 8 , 2 , 1 , 1 , false ));
4060+ for (ggml_type type : all_types) {
4061+ for (int b : {1 , 7 }) {
4062+ for (bool v : {false , true }) {
4063+ test_cases.emplace_back (new test_set_rows (type, 256 , 5 , 4 , b, 1 , v));
4064+ }
4065+ }
4066+ }
4067+
39874068 for (ggml_type type_input : {GGML_TYPE_F32}) {
39884069 for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
39894070 for (int k0 : {1 , 3 }) {
0 commit comments