diff --git a/rumale-naive_bayes/lib/rumale/naive_bayes/multinomial_nb.rb b/rumale-naive_bayes/lib/rumale/naive_bayes/multinomial_nb.rb index 420e372a..c060cd40 100644 --- a/rumale-naive_bayes/lib/rumale/naive_bayes/multinomial_nb.rb +++ b/rumale-naive_bayes/lib/rumale/naive_bayes/multinomial_nb.rb @@ -65,9 +65,9 @@ def decision_function(x) x = ::Rumale::Validation.check_convert_sample_array(x) n_classes = @classes.size - bin_x = x.gt(0) + bin_x = Numo::DFloat.cast(x.gt(0)) log_likelihoods = Array.new(n_classes) do |l| - Math.log(@class_priors[l]) + (Numo::DFloat[*bin_x] * Numo::NMath.log(@feature_probs[l, true])).sum(axis: 1) + Math.log(@class_priors[l]) + (bin_x * Numo::NMath.log(@feature_probs[l, true])).sum(axis: 1) end Numo::DFloat[*log_likelihoods].transpose.dup end diff --git a/rumale-naive_bayes/spec/rumale/naive_bayes/multinomial_nb_spec.rb b/rumale-naive_bayes/spec/rumale/naive_bayes/multinomial_nb_spec.rb index a37f8217..4eae0e9f 100644 --- a/rumale-naive_bayes/spec/rumale/naive_bayes/multinomial_nb_spec.rb +++ b/rumale-naive_bayes/spec/rumale/naive_bayes/multinomial_nb_spec.rb @@ -51,4 +51,16 @@ expect(probs.shape[1]).to eq(n_classes) expect(predicted_by_probs).to eq(y) end + + context 'with large sample sizes' do + let(:x) { Numo::DFloat.new(1_000_000, 4).rand } + let(:y) { Numo::Int32.new(1_000_000).rand(-1, 1) } + + it 'does not raise SystemStackError', :aggregate_failures do + expect do + estimator + score + end.not_to raise_error + end + end end