Skip to content

Commit 5e5f009

Browse files
committed
Add Enumerable#sample(n=1, ...)
- Only support the case of n == 1
1 parent 24a577b commit 5e5f009

File tree

2 files changed

+213
-0
lines changed

2 files changed

+213
-0
lines changed

ext/enumerable/statistics/extension/statistics.c

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1433,6 +1433,124 @@ enum_stdev(int argc, VALUE* argv, VALUE obj)
14331433
return stdev;
14341434
}
14351435

1436+
#if SIZEOF_SIZE_T == SIZEOF_LONG
1437+
static inline size_t
1438+
random_usize_limited(VALUE rnd, size_t max)
1439+
{
1440+
return (size_t)rb_random_ulong_limited(rnd, max);
1441+
}
1442+
#else
1443+
static inline size_t
1444+
random_usize_limited(VALUE rnd, size_t max)
1445+
{
1446+
if (max <= ULONG_MAX) {
1447+
return (size_t)rb_random_ulong_limited(rnd, (unsigned long)max);
1448+
}
1449+
else {
1450+
VALUE num = rb_random_int(rnd, SIZET2NUM(max));
1451+
return NUM2SIZET(num);
1452+
}
1453+
}
1454+
#endif
1455+
1456+
struct sample_single_memo {
1457+
size_t k;
1458+
VALUE sample;
1459+
VALUE random;
1460+
};
1461+
1462+
static VALUE
1463+
enum_sample_single_i(RB_BLOCK_CALL_FUNC_ARGLIST(e, data))
1464+
{
1465+
struct sample_single_memo *memo = (struct sample_single_memo *)data;
1466+
ENUM_WANT_SVALUE();
1467+
1468+
if (++memo->k <= 1) {
1469+
memo->sample = e;
1470+
}
1471+
else {
1472+
size_t j = random_usize_limited(memo->random, memo->k - 1);
1473+
if (j == 1) {
1474+
memo->sample = e;
1475+
}
1476+
}
1477+
1478+
return Qnil;
1479+
}
1480+
1481+
static VALUE
1482+
enum_sample_single(VALUE obj, VALUE random)
1483+
{
1484+
struct sample_single_memo memo;
1485+
1486+
memo.k = 0;
1487+
memo.sample = Qundef;
1488+
memo.random = random;
1489+
1490+
rb_block_call(obj, id_each, 0, 0, enum_sample_single_i, (VALUE)&memo);
1491+
1492+
return memo.sample;
1493+
}
1494+
1495+
static VALUE
1496+
enum_sample_multiple_unweighted(VALUE obj, long size, int replace_p)
1497+
{
1498+
assert(size > 1);
1499+
1500+
return Qnil;
1501+
}
1502+
1503+
/* call-seq:
1504+
* enum.sample(n=1, random: Random, replace: false)
1505+
*/
1506+
static VALUE
1507+
enum_sample(int argc, VALUE *argv, VALUE obj)
1508+
{
1509+
VALUE size_v, random_v, replace_v, weights_v, opts;
1510+
long size;
1511+
int replace_p;
1512+
1513+
random_v = rb_cRandom;
1514+
replace_v = Qundef;
1515+
weights_v = Qundef;
1516+
1517+
if (argc == 0) goto single;
1518+
1519+
rb_scan_args(argc, argv, "01:", &size_v, &opts);
1520+
size = NIL_P(size_v) ? 1 : NUM2LONG(size_v);
1521+
1522+
if (size == 1 && NIL_P(opts)) {
1523+
goto single;
1524+
}
1525+
1526+
if (!NIL_P(opts)) {
1527+
static ID keywords[3];
1528+
VALUE kwargs[3];
1529+
if (!keywords[0]) {
1530+
keywords[0] = rb_intern("random");
1531+
keywords[1] = rb_intern("replace");
1532+
/* keywords[2] = rb_intern("weights"); */
1533+
}
1534+
rb_get_kwargs(opts, keywords, 0, 2, kwargs);
1535+
random_v = kwargs[0];
1536+
replace_v = kwargs[1];
1537+
/* weights_v = kwargs[2]; */
1538+
}
1539+
if (random_v == Qundef) {
1540+
random_v = rb_cRandom;
1541+
}
1542+
1543+
if (size == 1) {
1544+
single:
1545+
return enum_sample_single(obj, random_v);
1546+
}
1547+
1548+
replace_p = (replace_v == Qundef) ? 1 : RTEST(replace_v);
1549+
1550+
return enum_sample_unweighted(obj, NUM2LONG(size), replace_p);
1551+
}
1552+
1553+
14361554
/* call-seq:
14371555
* ary.mean_stdev(population: false)
14381556
*
@@ -1499,6 +1617,7 @@ Init_extension(void)
14991617
rb_define_method(rb_mEnumerable, "variance", enum_variance, -1);
15001618
rb_define_method(rb_mEnumerable, "mean_stdev", enum_mean_stdev, -1);
15011619
rb_define_method(rb_mEnumerable, "stdev", enum_stdev, -1);
1620+
rb_define_method(rb_mEnumerable, "sample", enum_sample, -1);
15021621

15031622
#ifndef HAVE_ARRAY_SUM
15041623
rb_define_method(rb_cArray, "sum", ary_sum, -1);

spec/enum/sample_spec.rb

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
require 'spec_helper'
2+
require 'enumerable/statistics'
3+
4+
RSpec.describe Enumerable, '#sample' do
5+
let(:random) { Random.new }
6+
let(:n) { 20 }
7+
8+
context 'without weight' do
9+
let(:enum) { 1.upto(100000) }
10+
11+
context 'without size' do
12+
context 'without rng' do
13+
context 'without weight' do
14+
specify do
15+
result = enum.sample
16+
expect(result).to be_an(Integer)
17+
other_results = Array.new(100) { enum.sample }
18+
expect(other_results).not_to be_all(eq result)
19+
end
20+
end
21+
end
22+
23+
context 'with rng' do
24+
specify do
25+
save_random = random.dup
26+
result = enum.sample(random: random)
27+
expect(result).to be_an(Integer)
28+
other_results = Array.new(100) { enum.sample(random: save_random.dup) }
29+
expect(other_results).to be_all(eq result)
30+
end
31+
end
32+
end
33+
34+
context 'with size (== 1)' do
35+
context 'without rng' do
36+
context 'without weight' do
37+
specify do
38+
result = enum.sample(1)
39+
expect(result).to be_an(Integer)
40+
other_results = Array.new(100) { enum.sample(1) }
41+
expect(other_results).not_to be_all(eq result)
42+
end
43+
end
44+
end
45+
46+
context 'with rng' do
47+
specify do
48+
save_random = random.dup
49+
result = enum.sample(1, random: random)
50+
expect(result).to be_an(Integer)
51+
other_results = Array.new(100) { enum.sample(1, random: save_random.dup) }
52+
expect(other_results).to be_all(eq result)
53+
end
54+
end
55+
end
56+
57+
context 'with size (> 1)' do
58+
context 'without replacement' do
59+
context 'without rng' do
60+
subject(:result) { enum.sample(n) }
61+
62+
specify do
63+
result = enum.sample(n)
64+
expect(result).to be_an(Array)
65+
expect(result.length).to eq(n)
66+
other_results = Array.new(100) { enum.sample(n) }
67+
expect(other_results).not_to be_all(eq result)
68+
end
69+
end
70+
71+
context 'with rng' do
72+
subject(:result) { enum.sample(n, random: random) }
73+
74+
specify do
75+
save_random = random.dup
76+
result = enum.sample(n, random: random)
77+
expect(result).to be_an(Array)
78+
expect(result.length).to eq(n)
79+
other_results = Array.new(100) { enum.sample(n, random: random) }
80+
expect(other_results).to be_all(eq result)
81+
end
82+
end
83+
end
84+
85+
context 'with replacement' do
86+
pending
87+
end
88+
end
89+
end
90+
91+
context 'with weight' do
92+
pending
93+
end
94+
end

0 commit comments

Comments
 (0)