Skip to content

Commit 4564f15

Browse files
committed
feat: mdspan
1 parent 8d9f529 commit 4564f15

File tree

6 files changed

+219
-59
lines changed

6 files changed

+219
-59
lines changed

include/misc/needle_matrix.hpp

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
// SPDX-FileCopyrightText: 2006-2025 Knut Reinert & Freie Universität Berlin
2+
// SPDX-FileCopyrightText: 2016-2025 Knut Reinert & MPI für molekulare Genetik
3+
// SPDX-License-Identifier: BSD-3-Clause
4+
5+
#pragma once
6+
7+
#include <array>
8+
#include <cstddef>
9+
#include <mdspan>
10+
#include <span>
11+
#include <vector>
12+
13+
template <typename value_t>
14+
class needle_matrix
15+
{
16+
std::vector<value_t> data_;
17+
size_t levels_{};
18+
size_t experiments_{};
19+
20+
public:
21+
needle_matrix() = default;
22+
needle_matrix(needle_matrix const &) = default;
23+
needle_matrix(needle_matrix &&) = default;
24+
needle_matrix & operator=(needle_matrix const &) = default;
25+
needle_matrix & operator=(needle_matrix &&) = default;
26+
~needle_matrix() = default;
27+
28+
needle_matrix(size_t levels, size_t experiments) :
29+
data_(levels * experiments),
30+
levels_(levels),
31+
experiments_(experiments)
32+
{}
33+
34+
template <typename self_t>
35+
[[nodiscard]] constexpr auto * data(this self_t && self) noexcept
36+
{
37+
return self.data_.data();
38+
}
39+
40+
// Access via [level, experiment]
41+
template <typename self_t>
42+
[[nodiscard]] constexpr auto && operator[](this self_t && self, size_t lvl, size_t exp) noexcept
43+
{
44+
return self.view()[lvl, exp];
45+
}
46+
47+
// Flat 1D access (for bin-wise operations)
48+
template <typename self_t>
49+
[[nodiscard]] constexpr auto && operator[](this self_t && self, size_t bin) noexcept
50+
{
51+
return self.data()[bin];
52+
}
53+
54+
template <typename self_t>
55+
[[nodiscard]] constexpr auto view(this self_t && self) noexcept
56+
{
57+
return std::mdspan(self.data(), self.levels(), self.experiments());
58+
}
59+
60+
// Returns a contiguous span for a single level (row)
61+
template <typename self_t>
62+
[[nodiscard]] constexpr auto level(this self_t && self, size_t lvl) noexcept
63+
{
64+
return std::span(self.data() + (lvl * self.experiments()), self.experiments());
65+
}
66+
67+
// Returns a strided mdspan for a single experiment (column)
68+
template <typename self_t>
69+
[[nodiscard]] constexpr auto experiment(this self_t && self, size_t exp) noexcept
70+
{
71+
return std::mdspan(self.data() + exp,
72+
std::layout_stride::mapping{std::extents<size_t, std::dynamic_extent>{self.levels()},
73+
std::array<size_t, 1>{self.experiments()}});
74+
}
75+
76+
[[nodiscard]] constexpr size_t levels() const noexcept
77+
{
78+
return levels_;
79+
}
80+
[[nodiscard]] constexpr size_t experiments() const noexcept
81+
{
82+
return experiments_;
83+
}
84+
[[nodiscard]] constexpr size_t size() const noexcept
85+
{
86+
return data_.size();
87+
}
88+
89+
// Invalidates all views (.view(), .experiment(), .level())!
90+
void add_level()
91+
{
92+
++levels_;
93+
data_.resize(levels() * experiments());
94+
}
95+
};

include/misc/read_levels.hpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,13 @@
77
#include <filesystem>
88
#include <vector>
99

10+
template <typename value_t>
11+
class needle_matrix;
12+
1013
// Reads the level file ibf creates
1114
template <typename float_or_int>
1215
void read_levels(std::vector<std::vector<float_or_int>> & expressions, std::filesystem::path const & filename);
16+
17+
// Overload for needle_matrix
18+
template <typename float_or_int>
19+
void read_levels(needle_matrix<float_or_int> & expressions, std::filesystem::path const & filename);

src/estimate.cpp

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "misc/debug.hpp"
1414
#include "misc/filenames.hpp"
15+
#include "misc/needle_matrix.hpp"
1516
#include "misc/read_levels.hpp"
1617

1718
inline std::vector<uint64_t> get_minimiser(seqan3::dna4_vector const & seq, minimiser_arguments const & args)
@@ -27,7 +28,7 @@ inline std::vector<uint64_t> get_minimiser(seqan3::dna4_vector const & seq, mini
2728
template <typename ibf_t, bool normalization, typename exp_t>
2829
void check_ibf(estimate_ibf_arguments const & args,
2930
ibf_t const & ibf,
30-
std::vector<uint16_t> & estimations,
31+
std::span<uint16_t> estimations,
3132
seqan3::dna4_vector const & seq,
3233
exp_t const & expressions,
3334
std::vector<std::vector<double>> const & fprs,
@@ -36,7 +37,7 @@ void check_ibf(estimate_ibf_arguments const & args,
3637
size_t const num_experiments)
3738
{
3839
// Check, if one expression threshold for all or individual thresholds
39-
static constexpr bool multiple_expressions = std::same_as<exp_t, std::vector<std::vector<uint16_t>>>;
40+
static constexpr bool multiple_expressions = std::same_as<exp_t, needle_matrix<uint16_t>>;
4041

4142
std::vector<uint64_t> const minimiser = get_minimiser(seq, args);
4243
uint64_t const minimiser_count = minimiser.size();
@@ -86,7 +87,7 @@ void check_ibf(estimate_ibf_arguments const & args,
8687
if (level == num_levels - 1) // This is the last (lowest) level
8788
{
8889
if constexpr (multiple_expressions)
89-
estimations[experiment] = expressions[level][experiment];
90+
estimations[experiment] = expressions[level, experiment];
9091
else
9192
estimations[experiment] = args.expression_thresholds[level];
9293
}
@@ -102,11 +103,11 @@ void check_ibf(estimate_ibf_arguments const & args,
102103
// Actually calculate estimation, in the else case level stands for the prev_expression
103104
if constexpr (multiple_expressions)
104105
{
105-
size_t const prev_level_expression = expressions[level + 1][experiment];
106-
size_t const expression_difference = prev_level_expression - expressions[level][experiment];
106+
size_t const prev_level_expression = expressions[level + 1, experiment];
107+
size_t const expression_difference = prev_level_expression - expressions[level, experiment];
107108
size_t const estimate =
108109
prev_level_expression - (normalized_minimiser_pos * expression_difference);
109-
estimations[experiment] = std::max<size_t>(expressions[level][experiment], estimate);
110+
estimations[experiment] = std::max<size_t>(expressions[level, experiment], estimate);
110111
}
111112
else
112113
{
@@ -121,7 +122,7 @@ void check_ibf(estimate_ibf_arguments const & args,
121122
// Apply normalization if requested
122123
// TODO: Is this meant to be expressions[0]?
123124
if constexpr (normalization && multiple_expressions)
124-
estimations[experiment] /= expressions[1][experiment]; // Normalize by first level
125+
estimations[experiment] /= expressions[1, experiment]; // Normalize by first level
125126

126127
break; // Found the estimate for this experiment
127128
}
@@ -149,9 +150,9 @@ void estimate(estimate_ibf_arguments & args,
149150
// ========================================================================
150151
// const data
151152
// ========================================================================
152-
std::vector<std::vector<uint16_t>> const expressions = [&]()
153+
needle_matrix<uint16_t> const expressions = [&]()
153154
{
154-
std::vector<std::vector<uint16_t>> result;
155+
needle_matrix<uint16_t> result;
155156
if constexpr (samplewise)
156157
read_levels<uint16_t>(result, filenames::levels(estimate_args.path_in));
157158
return result;
@@ -186,8 +187,8 @@ void estimate(estimate_ibf_arguments & args,
186187
// ========================================================================
187188
std::vector<std::string> ids;
188189
std::vector<seqan3::dna4_vector> seqs;
189-
std::vector<std::vector<float>> prev_counts;
190-
std::vector<std::vector<uint16_t>> estimations;
190+
needle_matrix<float> prev_counts;
191+
needle_matrix<uint16_t> estimations;
191192
bool counters_initialised = false;
192193

193194
// ========================================================================
@@ -221,29 +222,17 @@ void estimate(estimate_ibf_arguments & args,
221222
// ========================================================================
222223
auto init_counter = [&](size_t const size)
223224
{
224-
static_assert(std::same_as<std::ranges::range_value_t<decltype(prev_counts)>, std::vector<float>>);
225-
prev_counts.resize(size, std::vector<float>(num_experiments));
226-
227-
static_assert(std::same_as<std::ranges::range_value_t<decltype(estimations)>, std::vector<uint16_t>>);
228-
estimations.resize(size, std::vector<uint16_t>(num_experiments));
229-
225+
prev_counts = needle_matrix<float>{size, num_experiments};
226+
estimations = needle_matrix<uint16_t>{size, num_experiments};
230227
return true;
231228
};
232229

233230
auto clear_data = [&]()
234231
{
235232
ids.clear();
236233
seqs.clear();
237-
std::ranges::for_each(prev_counts,
238-
[](auto & v)
239-
{
240-
std::ranges::fill(v, float{});
241-
});
242-
std::ranges::for_each(estimations,
243-
[](auto & v)
244-
{
245-
std::ranges::fill(v, uint16_t{});
246-
});
234+
std::ranges::fill_n(prev_counts.data(), prev_counts.size(), float{});
235+
std::ranges::fill_n(estimations.data(), estimations.size(), uint16_t{});
247236
};
248237

249238
auto process_ibf = [&](size_t const i)
@@ -252,7 +241,7 @@ void estimate(estimate_ibf_arguments & args,
252241
{
253242
check_ibf<ibf_t, normalization_method>(args,
254243
ibf,
255-
estimations[i],
244+
estimations.level(i),
256245
seqs[i],
257246
expressions,
258247
fprs,
@@ -264,7 +253,7 @@ void estimate(estimate_ibf_arguments & args,
264253
{
265254
check_ibf<ibf_t, false>(args,
266255
ibf,
267-
estimations[i],
256+
estimations.level(i),
268257
seqs[i],
269258
args.expression_thresholds,
270259
fprs,
@@ -303,7 +292,7 @@ void estimate(estimate_ibf_arguments & args,
303292
{
304293
outfile << ids[i] << '\t';
305294
for (size_t j = 0; j < num_experiments; ++j)
306-
outfile << estimations[i][j] << '\t';
295+
outfile << estimations[i, j] << '\t';
307296
outfile << '\n';
308297
}
309298
}

src/ibf.cpp

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
#include "misc/fill_hash_table.hpp"
2020
#include "misc/get_expression_thresholds.hpp"
2121
#include "misc/get_include_set_table.hpp"
22+
#include "misc/needle_matrix.hpp"
2223
#include "misc/stream.hpp"
2324

2425
// Check number of expression levels, sort expression levels
@@ -142,7 +143,7 @@ void store_fpr_information(seqan::hibf::hierarchical_interleaved_bloom_filter co
142143
};
143144

144145
size_t const num_levels = ibf_args.number_expression_thresholds;
145-
std::vector<double> actual_fprs(num_files * num_levels);
146+
needle_matrix<double> fprs{num_levels, num_files};
146147

147148
auto traverse_hibf = [&](this auto self, size_t const ibf_idx) -> void
148149
{
@@ -171,7 +172,7 @@ void store_fpr_information(seqan::hibf::hierarchical_interleaved_bloom_filter co
171172
default:
172173
if (tbin + 1u == ibf.bin_count() || user_bin_id != ub_ids[tbin + 1]) // last bin || end of split bin
173174
{
174-
actual_fprs[user_bin_id] = compute_combined_fpr(sum);
175+
fprs[user_bin_id] = compute_combined_fpr(sum);
175176
sum = 0u;
176177
}
177178
}
@@ -181,12 +182,12 @@ void store_fpr_information(seqan::hibf::hierarchical_interleaved_bloom_filter co
181182
traverse_hibf(0);
182183

183184
std::ofstream outfile{filenames::fprs(ibf_args.path_out)};
184-
for (unsigned level = 0; level < ibf_args.number_expression_thresholds; level++)
185+
for (size_t level = 0; level < fprs.levels(); ++level)
185186
{
186-
for (size_t file = 0; file < num_files; file++)
187+
auto current_level = fprs.level(level);
188+
for (auto && fpr : current_level)
187189
{
188-
size_t const bin = level * num_files + file;
189-
outfile << actual_fprs[bin] << " ";
190+
outfile << fpr << " ";
190191
}
191192
outfile << "\n";
192193
}
@@ -211,17 +212,15 @@ void ibf_helper(std::vector<std::filesystem::path> const & minimiser_files,
211212
return minimiser_args.samples.size();
212213
}();
213214

214-
std::vector<std::vector<uint16_t>> expressions = [&]()
215+
needle_matrix<uint16_t> expressions = [&]()
215216
{
216-
std::vector<std::vector<uint16_t>> result;
217217
if constexpr (samplewise)
218-
result.resize(num_files, std::vector<uint16_t>(ibf_args.number_expression_thresholds));
219-
return result;
218+
return needle_matrix<uint16_t>{ibf_args.number_expression_thresholds, num_files};
219+
else
220+
return needle_matrix<uint16_t>{};
220221
}();
221222

222223
std::vector<std::vector<uint64_t>> sizes(num_files);
223-
std::vector<std::vector<uint64_t>> counts_per_level(num_files,
224-
std::vector<uint64_t>(ibf_args.number_expression_thresholds));
225224

226225
bool const calculate_cutoffs = cutoffs.empty();
227226

@@ -383,7 +382,9 @@ void ibf_helper(std::vector<std::filesystem::path> const & minimiser_files,
383382
genome,
384383
cutoffs[i],
385384
expression_by_genome);
386-
expressions[i] = expression_thresholds;
385+
auto experiment = expressions.experiment(i);
386+
for (size_t j = 0; j < experiment.extent(0); ++j)
387+
experiment[j] = expression_thresholds[j];
387388
}
388389

389390
// Collect insertions for this file
@@ -397,7 +398,7 @@ void ibf_helper(std::vector<std::filesystem::path> const & minimiser_files,
397398
uint16_t const threshold = [&]()
398399
{
399400
if constexpr (samplewise)
400-
return expressions[i][j];
401+
return expressions[j, i];
401402
else
402403
return ibf_args.expression_thresholds[j];
403404
}();
@@ -406,7 +407,6 @@ void ibf_helper(std::vector<std::filesystem::path> const & minimiser_files,
406407
{
407408
size_t bin_index = j * num_files + i;
408409
target_bins.push_back(bin_index);
409-
counts_per_level[i][j]++;
410410
break;
411411
}
412412
}
@@ -473,10 +473,11 @@ void ibf_helper(std::vector<std::filesystem::path> const & minimiser_files,
473473
if constexpr (samplewise)
474474
{
475475
std::ofstream outfile{filenames::levels(ibf_args.path_out)};
476-
for (unsigned j = 0; j < ibf_args.number_expression_thresholds; j++)
476+
for (unsigned j = 0; j < expressions.levels(); j++)
477477
{
478-
for (size_t i = 0; i < num_files; i++)
479-
outfile << expressions[i][j] << " ";
478+
auto current_level = expressions.level(j);
479+
for (auto && expr : current_level)
480+
outfile << expr << " ";
480481
outfile << "\n";
481482
}
482483
outfile << "/\n";

0 commit comments

Comments
 (0)