Skip to content

Commit b08924e

Browse files
committed
std.math.big.int: changed llshr and llshl implementation
1 parent b635b37 commit b08924e

File tree

1 file changed

+99
-54
lines changed

1 file changed

+99
-54
lines changed

lib/std/math/big/int.zig

Lines changed: 99 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ const Endian = std.builtin.Endian;
1717
const Signedness = std.builtin.Signedness;
1818
const native_endian = builtin.cpu.arch.endian();
1919

20-
2120
/// Returns the number of limbs needed to store `scalar`, which must be a
2221
/// primitive integer value.
2322
/// Note: A comptime-known upper bound of this value that may be used
@@ -210,7 +209,7 @@ pub const Mutable = struct {
210209
for (self.limbs[0..self.len]) |limb| {
211210
std.debug.print("{x} ", .{limb});
212211
}
213-
std.debug.print("capacity={} positive={}\n", .{ self.limbs.len, self.positive });
212+
std.debug.print("len={} capacity={} positive={}\n", .{ self.len, self.limbs.len, self.positive });
214213
}
215214

216215
/// Clones a Mutable and returns a new Mutable with the same value. The new Mutable is a deep copy and
@@ -1104,8 +1103,8 @@ pub const Mutable = struct {
11041103
/// Asserts there is enough memory to fit the result. The upper bound Limb count is
11051104
/// `a.limbs.len + (shift / (@sizeOf(Limb) * 8))`.
11061105
/// r = a << shift
///
/// Asserts there is enough memory to fit the result. The upper bound Limb count is
/// `a.limbs.len + (shift / (@sizeOf(Limb) * 8))`.
pub fn shiftLeft(r: *Mutable, a: Const, shift: usize) void {
    // llshl reports how many limbs it actually wrote, so no manual
    // length computation is needed here.
    const result_len = llshl(r.limbs, a.limbs, shift);
    r.normalize(result_len);
    r.positive = a.positive;
}
11111110

@@ -1173,16 +1172,16 @@ pub const Mutable = struct {
11731172

11741173
// This shift should not be able to overflow, so invoke llshl and normalize manually
11751174
// to avoid the extra required limb.
1176-
llshl(r.limbs, a.limbs, shift);
1177-
r.normalize(a.limbs.len + (shift / limb_bits));
1175+
const new_len = llshl(r.limbs, a.limbs, shift);
1176+
r.normalize(new_len);
11781177
r.positive = a.positive;
11791178
}
11801179

11811180
/// r = a >> shift
11821181
/// r and a may alias.
11831182
///
11841183
/// Asserts there is enough memory to fit the result. The upper bound Limb count is
1185-
/// `a.limbs.len - (shift / (@sizeOf(Limb) * 8))`.
1184+
/// `a.limbs.len - (shift / (@bitSizeOf(Limb)))`.
11861185
pub fn shiftRight(r: *Mutable, a: Const, shift: usize) void {
11871186
const full_limbs_shifted_out = shift / limb_bits;
11881187
const remaining_bits_shifted_out = shift % limb_bits;
@@ -1210,9 +1209,9 @@ pub const Mutable = struct {
12101209
break :nonzero a.limbs[full_limbs_shifted_out] << not_covered != 0;
12111210
};
12121211

1213-
llshr(r.limbs, a.limbs, shift);
1212+
const new_len = llshr(r.limbs, a.limbs, shift);
12141213

1215-
r.len = a.limbs.len - full_limbs_shifted_out;
1214+
r.len = new_len;
12161215
r.positive = a.positive;
12171216
if (nonzero_negative_shiftout) r.addScalar(r.toConst(), -1);
12181217
r.normalize(r.len);
@@ -1971,7 +1970,7 @@ pub const Const = struct {
19711970
for (self.limbs[0..self.limbs.len]) |limb| {
19721971
std.debug.print("{x} ", .{limb});
19731972
}
1974-
std.debug.print("positive={}\n", .{self.positive});
1973+
std.debug.print("len={} positive={}\n", .{ self.len, self.positive });
19751974
}
19761975

19771976
pub fn abs(self: Const) Const {
@@ -2673,7 +2672,7 @@ pub const Managed = struct {
26732672
for (self.limbs[0..self.len()]) |limb| {
26742673
std.debug.print("{x} ", .{limb});
26752674
}
2676-
std.debug.print("capacity={} positive={}\n", .{ self.limbs.len, self.isPositive() });
2675+
std.debug.print("len={} capacity={} positive={}\n", .{ self.len(), self.limbs.len, self.isPositive() });
26772676
}
26782677

26792678
/// Negate the sign.
@@ -3711,68 +3710,114 @@ fn lldiv0p5(quo: []Limb, rem: *Limb, a: []const Limb, b: HalfLimb) void {
37113710
}
37123711
}
37133712

/// Performs r = a << shift and returns the number of limbs written.
///
/// If a and r overlap, r.ptr >= a.ptr is asserted.
/// r must have the capacity to store a << shift.
fn llshl(r: []Limb, a: []const Limb, shift: usize) usize {
    std.debug.assert(a.len >= 1);
    if (slicesOverlap(a, r))
        std.debug.assert(@intFromPtr(r.ptr) >= @intFromPtr(a.ptr));

    if (shift == 0) {
        // Pure copy. copyBackwards is the safe direction, since any
        // overlap is asserted to have r.ptr >= a.ptr.
        if (a.ptr != r.ptr)
            std.mem.copyBackwards(Limb, r[0..a.len], a);
        return a.len;
    }

    if (shift >= limb_bits) {
        // Perform the sub-limb part of the shift higher up in r, then
        // zero-fill the vacated low limbs.
        const whole_limbs = shift / limb_bits;
        const written = llshl(r[whole_limbs..], a, shift % limb_bits);
        @memset(r[0..whole_limbs], 0);
        return whole_limbs + written;
    }

    // From here on, 0 < shift < limb_bits.
    const bit_shift: Log2Limb = @truncate(shift);
    const carry_shift: Log2Limb = @truncate(limb_bits - bit_shift);

    // We only need the extra limb if shifting the most significant limb
    // overflows. This is useful for the implementation of `shiftLeftSat`.
    const spills_over = (a[a.len - 1] >> carry_shift) != 0;
    std.debug.assert(r.len >= a.len + @intFromBool(spills_over));

    // r is asserted to be large enough above.
    if (spills_over) r[a.len] = a[a.len - 1] >> carry_shift;

    // Walk from the top down so that r may alias a (r.ptr >= a.ptr).
    var dst = a.len - 1;
    while (dst > 0) : (dst -= 1) {
        r[dst] = (a[dst] << bit_shift) | (a[dst - 1] >> carry_shift);
    }
    r[0] = a[0] << bit_shift;

    return if (spills_over) a.len + 1 else a.len;
}
/// Performs r = a >> shift and returns the number of limbs written.
///
/// If a and r overlap, r.ptr <= a.ptr is asserted.
/// r must have the capacity to store a >> shift.
///
/// See tests below for examples of behaviour.
fn llshr(r: []Limb, a: []const Limb, shift: usize) usize {
    if (slicesOverlap(a, r))
        std.debug.assert(@intFromPtr(r.ptr) <= @intFromPtr(a.ptr));

    if (a.len == 0) return 0;

    if (shift == 0) {
        std.debug.assert(r.len >= a.len);
        // Pure copy. copyForwards is the safe direction, since any
        // overlap is asserted to have r.ptr <= a.ptr.
        if (a.ptr != r.ptr)
            std.mem.copyForwards(Limb, r[0..a.len], a);
        return a.len;
    }
    if (shift >= limb_bits) {
        if (shift / limb_bits >= a.len) {
            // Everything is shifted out; the result is a single zero limb.
            r[0] = 0;
            return 1;
        }
        // Drop the fully shifted-out limbs, then handle the sub-limb part.
        // The slice passed to the recursive call is guaranteed non-empty.
        return llshr(r, a[shift / limb_bits ..], shift % limb_bits);
    }

    // shift is guaranteed to be < limb_bits here.
    const bit_shift: Log2Limb = @truncate(shift);
    const opposite_bit_shift: Log2Limb = @truncate(limb_bits - bit_shift);

    // Special case, where there is a risk to set r to 0: with a single
    // limb we must still store (and count) the possibly-zero result limb.
    if (a.len == 1) {
        r[0] = a[0] >> bit_shift;
        return 1;
    }
    // NOTE(review): the previous `if (a.len == 0)` branch here was
    // unreachable (a.len == 0 already returned above, and the recursive
    // call never passes an empty slice); it has been removed.

    // Whether the most significant limb becomes 0 after the shift.
    const shrink = a[a.len - 1] >> bit_shift == 0;
    // Fix: capacity required is a.len when the top limb survives
    // (`r[a.len - 1]` is written below) and a.len - 1 when it shrinks;
    // the original asserted with `!shrink`, i.e. the inverse.
    std.debug.assert(r.len >= a.len - @intFromBool(shrink));

    var i: usize = 0;
    while (i < a.len - 1) : (i += 1) {
        r[i] = (a[i] >> bit_shift) | (a[i + 1] << opposite_bit_shift);
    }

    if (!shrink)
        r[i] = a[i] >> bit_shift;

    return a.len - @intFromBool(shrink);
}
37723818

37733819
// r = ~r
37743820
fn llnot(r: []Limb) void {
3775-
37763821
for (r) |*elem| {
37773822
elem.* = ~elem.*;
37783823
}
@@ -4107,7 +4152,7 @@ fn llsquareBasecase(r: []Limb, x: []const Limb) void {
41074152
}
41084153

41094154
// Each product appears twice, multiply by 2
4110-
llshl(r, r[0 .. 2 * x_norm.len], 1);
4155+
_ = llshl(r, r[0 .. 2 * x_norm.len], 1);
41114156

41124157
for (x_norm, 0..) |v, i| {
41134158
// Compute and add the squares

0 commit comments

Comments
 (0)