Skip to content

Commit e19e449

Browse files
authored
Implement general tensor_pow via Base.power_by_squaring (#54)
1 parent 3773c03 commit e19e449

File tree

4 files changed

+36
-8
lines changed

4 files changed

+36
-8
lines changed

src/QuantumInterface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ function tensor end
4545
const = tensor
4646
tensor() = throw(ArgumentError("Tensor function needs at least one argument."))
4747

48-
function tensor_pow end # TODO should Base.^ be the same as tensor_pow?
48+
function tensor_pow end
4949

5050
function traceout! end
5151

src/bases.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,7 @@ function tensor(b1::Basis, b2::CompositeBasis)
9393
end
9494
tensor(bases::Basis...) = reduce(tensor, bases)
9595

96-
function Base.:^(b::Basis, N::Integer)
97-
if N < 1
98-
throw(ArgumentError("Power of a basis is only defined for positive integers."))
99-
end
100-
tensor([b for i=1:N]...)
101-
end
96+
Base.:^(b::Basis, N::Integer) = tensor_pow(b, N)
10297

10398
"""
10499
equal_shape(a, b)

src/tensor.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,36 @@ tensor(op::AbstractOperator) = op
88
tensor(operators::AbstractOperator...) = reduce(tensor, operators)
99
tensor(state::StateVector) = state
1010
tensor(states::Vector{T}) where T<:StateVector = reduce(tensor, states)
11+
12+
"""
13+
tensor_pow(a, N)
14+
15+
Gives the tensor product of `a` `N` times.
16+
"""
17+
tensor_pow(a, N) = tensor_pow_by_squaring(a, N)
18+
19+
# Copied from Base.power_by_squaring as `mul` keyword dosn't work without implementing `*`
20+
function tensor_pow_by_squaring(x, p::Integer)
21+
if p == 1
22+
return x
23+
elseif p == 2
24+
return tensor(x, x)
25+
elseif p < 1
26+
throw(DomainError("Cannot take tensor_pow to power less than one"))
27+
end
28+
t = trailing_zeros(p) + 1
29+
p >>= t
30+
while (t -= 1) > 0
31+
x = tensor(x, x)
32+
end
33+
y = x
34+
while p > 0
35+
t = trailing_zeros(p) + 1
36+
p >>= t
37+
while (t -= 1) >= 0
38+
x = tensor(x, x)
39+
end
40+
y = tensor(y, x)
41+
end
42+
return y
43+
end

test/test_bases.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ comp_b2 = tensor(b1, b1, b2)
2828

2929
@test b1^3 == CompositeBasis(b1, b1, b1)
3030
@test (b1b2)^2 == CompositeBasis(b1, b2, b1, b2)
31-
@test_throws ArgumentError b1^(0)
31+
@test_throws DomainError b1^(0)
3232

3333
comp_b1_b2 = tensor(comp_b1, comp_b2)
3434
@test comp_b1_b2.shape == [prod(shape1), prod(shape2), prod(shape1), prod(shape1), prod(shape2)]

0 commit comments

Comments
 (0)