Skip to content

Commit d2c8304

Browse files
committed
Support n > 1 in Enumerable#sample
1 parent 5e5f009 commit d2c8304

File tree

2 files changed

+52
-14
lines changed

2 files changed

+52
-14
lines changed

ext/enumerable/statistics/extension/statistics.c

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1453,16 +1453,17 @@ random_usize_limited(VALUE rnd, size_t max)
14531453
}
14541454
#endif
14551455

1456-
struct sample_single_memo {
1456+
struct enum_sample_memo {
14571457
size_t k;
1458+
long n;
14581459
VALUE sample;
14591460
VALUE random;
14601461
};
14611462

14621463
static VALUE
14631464
enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
14641465
{
1465-
struct sample_single_memo *memo = (struct sample_single_memo *)data;
1466+
struct enum_sample_memo *memo = (struct enum_sample_memo *)data;
14661467
ENUM_WANT_SVALUE();
14671468

14681469
if (++memo->k <= 1) {
@@ -1481,9 +1482,10 @@ enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
14811482
static VALUE
14821483
enum_sample_single(VALUE obj, VALUE random)
14831484
{
1484-
struct sample_single_memo memo;
1485+
struct enum_sample_memo memo;
14851486

14861487
memo.k = 0;
1488+
memo.n = 1;
14871489
memo.sample = Qundef;
14881490
memo.random = random;
14891491

@@ -1493,13 +1495,46 @@ enum_sample_single(VALUE obj, VALUE random)
14931495
}
14941496

14951497
static VALUE
1496-
enum_sample_multiple_unweighted(VALUE obj, long size, int replace_p)
1498+
enum_sample_multiple_without_replace_unweighted_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
14971499
{
1498-
assert(size > 1);
1500+
struct enum_sample_memo *memo = (struct enum_sample_memo *)data;
1501+
ENUM_WANT_SVALUE();
1502+
1503+
if (++memo->k <= memo->n) {
1504+
rb_ary_push(memo->sample, e);
1505+
}
1506+
else {
1507+
size_t j = random_usize_limited(memo->random, memo->k - 1);
1508+
if (j <= memo->n) {
1509+
rb_ary_store(memo->sample, (long)(j - 1), e);
1510+
}
1511+
}
14991512

15001513
return Qnil;
15011514
}
15021515

1516+
static VALUE
1517+
enum_sample_multiple_unweighted(VALUE obj, long size, VALUE random, int replace_p)
1518+
{
1519+
struct enum_sample_memo memo;
1520+
1521+
assert(size > 1);
1522+
1523+
memo.k = 0;
1524+
memo.n = size;
1525+
memo.sample = rb_ary_new_capa(size);
1526+
memo.random = random;
1527+
1528+
if (replace_p) {
1529+
return Qnil;
1530+
}
1531+
else {
1532+
rb_block_call(obj, id_each, 0, 0, enum_sample_multiple_without_replace_unweighted_i, (VALUE)&memo);
1533+
}
1534+
1535+
return memo.sample;
1536+
}
1537+
15031538
/* call-seq:
15041539
* enum.sample(n=1, random: Random, replace: false)
15051540
*/
@@ -1536,6 +1571,7 @@ enum_sample(int argc, VALUE *argv, VALUE obj)
15361571
replace_v = kwargs[1];
15371572
/* weights_v = kwargs[2]; */
15381573
}
1574+
15391575
if (random_v == Qundef) {
15401576
random_v = rb_cRandom;
15411577
}
@@ -1545,9 +1581,9 @@ enum_sample(int argc, VALUE *argv, VALUE obj)
15451581
return enum_sample_single(obj, random_v);
15461582
}
15471583

1548-
replace_p = (replace_v == Qundef) ? 1 : RTEST(replace_v);
1584+
replace_p = (replace_v == Qundef) ? 0 : RTEST(replace_v);
15491585

1550-
return enum_sample_unweighted(obj, NUM2LONG(size), replace_p);
1586+
return enum_sample_multiple_unweighted(obj, size, random_v, replace_p);
15511587
}
15521588

15531589

spec/enum/sample_spec.rb

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
result = enum.sample
1616
expect(result).to be_an(Integer)
1717
other_results = Array.new(100) { enum.sample }
18-
expect(other_results).not_to be_all(eq result)
18+
expect(other_results).not_to be_all {|i| i == result }
1919
end
2020
end
2121
end
@@ -26,7 +26,7 @@
2626
result = enum.sample(random: random)
2727
expect(result).to be_an(Integer)
2828
other_results = Array.new(100) { enum.sample(random: save_random.dup) }
29-
expect(other_results).to be_all(eq result)
29+
expect(other_results).to be_all {|i| i == result }
3030
end
3131
end
3232
end
@@ -38,7 +38,7 @@
3838
result = enum.sample(1)
3939
expect(result).to be_an(Integer)
4040
other_results = Array.new(100) { enum.sample(1) }
41-
expect(other_results).not_to be_all(eq result)
41+
expect(other_results).not_to be_all {|i| i == result }
4242
end
4343
end
4444
end
@@ -49,7 +49,7 @@
4949
result = enum.sample(1, random: random)
5050
expect(result).to be_an(Integer)
5151
other_results = Array.new(100) { enum.sample(1, random: save_random.dup) }
52-
expect(other_results).to be_all(eq result)
52+
expect(other_results).to be_all {|i| i == result }
5353
end
5454
end
5555
end
@@ -63,8 +63,9 @@
6363
result = enum.sample(n)
6464
expect(result).to be_an(Array)
6565
expect(result.length).to eq(n)
66+
expect(result.uniq.length).to eq(n)
6667
other_results = Array.new(100) { enum.sample(n) }
67-
expect(other_results).not_to be_all(eq result)
68+
expect(other_results).not_to be_all {|i| i == result }
6869
end
6970
end
7071

@@ -76,8 +77,9 @@
7677
result = enum.sample(n, random: random)
7778
expect(result).to be_an(Array)
7879
expect(result.length).to eq(n)
79-
other_results = Array.new(100) { enum.sample(n, random: random) }
80-
expect(other_results).to be_all(eq result)
80+
expect(result.uniq.length).to eq(n)
81+
other_results = Array.new(100) { enum.sample(n, random: save_random.dup) }
82+
expect(other_results).to be_all {|i| i == result }
8183
end
8284
end
8385
end

0 commit comments

Comments
 (0)