Skip to content

Commit a6c731b

Browse files
committed
Refactors, adds doc and tests for modulation.py and utilities.py
1 parent 26c4369 commit a6c731b

File tree

4 files changed

+220
-43
lines changed

4 files changed

+220
-43
lines changed

commpy/modulation.py

Lines changed: 121 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,46 @@
3030
from numpy.fft import fft, ifft
3131
from numpy.linalg import qr, norm
3232

33-
from commpy.utilities import bitarray2dec, dec2bitarray
33+
from commpy.utilities import bitarray2dec, dec2bitarray, signal_power
3434

3535
__all__ = ['PSKModem', 'QAMModem', 'ofdm_tx', 'ofdm_rx', 'mimo_ml', 'kbest', 'best_first_detector',
3636
'bit_lvl_repr', 'max_log_approx']
3737

3838

3939
class Modem:
40+
41+
""" Creates a custom Modem object.
42+
43+
Parameters
44+
----------
45+
constellation : array-like with a length which is a power of 2
46+
Constellation of the custom modem
47+
48+
Attributes
49+
----------
50+
constellation : 1D-ndarray of complex
51+
Modem constellation. If changed, the length of the new constellation must be a power of 2.
52+
53+
Es : float
54+
Average energy per symbols.
55+
56+
m : integer
57+
Constellation length.
58+
59+
num_bits_symb : integer
60+
Number of bits per symbol.
61+
62+
Raises
63+
------
64+
ValueError
65+
If the constellation is changed to an array-like with length that is not a power of 2.
66+
"""
67+
68+
def __init__(self, constellation):
69+
""" Creates a custom Modem object. """
70+
71+
self.constellation = constellation
72+
4073
def modulate(self, input_bits):
4174
""" Modulate (map) an array of bits to constellation symbols.
4275
@@ -52,7 +85,7 @@ def modulate(self, input_bits):
5285
5386
"""
5487
mapfunc = vectorize(lambda i:
55-
self.constellation[bitarray2dec(input_bits[i:i + self.num_bits_symbol])])
88+
self._constellation[bitarray2dec(input_bits[i:i + self.num_bits_symbol])])
5689

5790
baseband_symbols = mapfunc(arange(0, len(input_bits), self.num_bits_symbol))
5891

@@ -80,9 +113,8 @@ def demodulate(self, input_symbols, demod_type, noise_var=0):
80113
81114
"""
82115
if demod_type == 'hard':
83-
index_list = map(lambda i: argmin(abs(input_symbols[i] - self.constellation)),
84-
range(0, len(input_symbols)))
85-
demod_bits = array([dec2bitarray(i, self.num_bits_symbol) for i in index_list]).reshape(-1)
116+
index_list = abs(input_symbols - self._constellation[:, None]).argmin(0)
117+
demod_bits = dec2bitarray(index_list, self.num_bits_symbol)
86118

87119
elif demod_type == 'soft':
88120
demod_bits = zeros(len(input_symbols) * self.num_bits_symbol)
@@ -94,10 +126,10 @@ def demodulate(self, input_symbols, demod_type, noise_var=0):
94126
for const_index in self.symbol_mapping:
95127
if (const_index >> bit_index) & 1:
96128
llr_num = llr_num + exp(
97-
(-abs(current_symbol - self.constellation[const_index]) ** 2) / noise_var)
129+
(-abs(current_symbol - self._constellation[const_index]) ** 2) / noise_var)
98130
else:
99131
llr_den = llr_den + exp(
100-
(-abs(current_symbol - self.constellation[const_index]) ** 2) / noise_var)
132+
(-abs(current_symbol - self._constellation[const_index]) ** 2) / noise_var)
101133
demod_bits[i * self.num_bits_symbol + self.num_bits_symbol - 1 - bit_index] = log(llr_num / llr_den)
102134
else:
103135
raise ValueError('demod_type must be "hard" or "soft"')
@@ -142,36 +174,92 @@ def plot_constellation(self):
142174
plt.grid()
143175
plt.show()
144176

177+
@property
178+
def constellation(self):
179+
""" Constellation of the modem. """
180+
return self._constellation
145181

146-
class PSKModem(Modem):
147-
""" Creates a Phase Shift Keying (PSK) Modem object. """
182+
@constellation.setter
183+
def constellation(self, value):
184+
# Check value input
185+
num_bits_symbol = log2(len(value))
186+
if num_bits_symbol != int(num_bits_symbol):
187+
raise ValueError('Constellation length must be a power of 2.')
148188

149-
Es = 1
189+
# Set constellation as an array
190+
self._constellation = array(value)
150191

151-
def _constellation_symbol(self, i):
152-
return cos(2 * pi * (i - 1) / self.m) + sin(2 * pi * (i - 1) / self.m) * (0 + 1j)
192+
# Update other attributes
193+
self.Es = signal_power(self.constellation)
194+
self.m = self._constellation.size
195+
self.num_bits_symbol = int(num_bits_symbol)
196+
self.constellation_mapping = arange(self.m)
153197

154-
def __init__(self, m):
155-
""" Creates a Phase Shift Keying (PSK) Modem object.
198+
199+
class PSKModem(Modem):
200+
""" Creates a Phase Shift Keying (PSK) Modem object.
156201
157202
Parameters
158203
----------
159204
m : int
160205
Size of the PSK constellation.
161206
207+
Attributes
208+
----------
209+
constellation : 1D-ndarray of complex
210+
Modem constellation. If changed, the length of the new constellation must be a power of 2.
211+
212+
Es : float
213+
Average energy per symbols.
214+
215+
m : integer
216+
Constellation length.
217+
218+
num_bits_symb : integer
219+
Number of bits per symbol.
220+
221+
Raises
222+
------
223+
ValueError
224+
If the constellation is changed to an array-like with length that is not a power of 2.
162225
"""
163-
self.m = m
164-
self.num_bits_symbol = int(log2(self.m))
165-
self.symbol_mapping = arange(self.m)
166-
self.constellation = list(map(self._constellation_symbol,
167-
self.symbol_mapping))
226+
227+
def __init__(self, m):
228+
""" Creates a Phase Shift Keying (PSK) Modem object. """
229+
230+
def _constellation_symbol(i):
231+
return cos(2 * pi * (i - 1) / m) + sin(2 * pi * (i - 1) / m) * (0 + 1j)
232+
233+
self.constellation = list(map(_constellation_symbol, arange(m)))
168234

169235

170236
class QAMModem(Modem):
171-
""" Creates a Quadrature Amplitude Modulation (QAM) Modem object."""
237+
""" Creates a Quadrature Amplitude Modulation (QAM) Modem object.
172238
173-
def _constellation_symbol(self, i):
174-
return (2 * i[0] - 1) + (2 * i[1] - 1) * (1j)
239+
Parameters
240+
----------
241+
m : int
242+
Size of the PSK constellation.
243+
244+
Attributes
245+
----------
246+
constellation : 1D-ndarray of complex
247+
Modem constellation. If changed, the length of the new constellation must be a power of 2.
248+
249+
Es : float
250+
Average energy per symbols.
251+
252+
m : integer
253+
Constellation length.
254+
255+
num_bits_symb : integer
256+
Number of bits per symbol.
257+
258+
Raises
259+
------
260+
ValueError
261+
If the constellation is changed to an array-like with length that is not a power of 2.
262+
"""
175263

176264
def __init__(self, m):
177265
""" Creates a Quadrature Amplitude Modulation (QAM) Modem object.
@@ -183,13 +271,12 @@ def __init__(self, m):
183271
184272
"""
185273

186-
self.m = m
187-
self.num_bits_symbol = int(log2(self.m))
188-
self.symbol_mapping = arange(self.m)
189-
mapping_array = arange(1, sqrt(self.m) + 1) - (sqrt(self.m) / 2)
190-
self.constellation = list(map(self._constellation_symbol,
274+
def _constellation_symbol(i):
275+
return (2 * i[0] - 1) + (2 * i[1] - 1) * (1j)
276+
277+
mapping_array = arange(1, sqrt(m) + 1) - (sqrt(m) / 2)
278+
self.constellation = list(map(_constellation_symbol,
191279
list(product(mapping_array, repeat=2))))
192-
self.Es = 2 * (self.m - 1) / 3
193280

194281

195282
def ofdm_tx(x, nfft, nsc, cp_length):
@@ -510,6 +597,11 @@ def bit_lvl_repr(H, w):
510597
------
511598
A : 2D nbarray (shape : nb_rx, nb_tx*beta)
512599
Channel matrix adapted to the bit-level representation.
600+
601+
raises
602+
------
603+
ValueError
604+
If beta (the length of w) is not even)
513605
"""
514606
beta = len(w)
515607
if beta % 2 == 0:
@@ -518,7 +610,7 @@ def bit_lvl_repr(H, w):
518610
kr = kron(In, w)
519611
return dot(H, kr)
520612
else:
521-
raise ValueError('Beta must be even.')
613+
raise ValueError('Beta (length of w) must be even.')
522614

523615

524616
def max_log_approx(y, h, noise_var, pts_list, demode):

commpy/tests/test_modulation.py

Lines changed: 71 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# Authors: CommPy contributors
22
# License: BSD 3-Clause
33

4+
from itertools import product
5+
46
from numpy import zeros, identity, arange, concatenate, log2, array, inf
57
from numpy.random import seed
6-
from numpy.testing import run_module_suite, assert_allclose, dec
8+
from numpy.testing import run_module_suite, assert_allclose, dec, assert_raises, assert_array_equal
79

810
from commpy.channels import MIMOFlatChannel
911
from commpy.links import *
10-
from commpy.modulation import QAMModem, mimo_ml, bit_lvl_repr, max_log_approx
12+
from commpy.modulation import QAMModem, mimo_ml, bit_lvl_repr, max_log_approx, PSKModem, Modem
13+
from commpy.utilities import signal_power
1114

1215

1316
@dec.slow
@@ -49,6 +52,10 @@ def receiver_without_blr(y, H, cons, noise_var):
4952
assert_allclose(ber_without_blr, ber_with_blr, rtol=0.5,
5053
err_msg='bit_lvl_repr changes the performance')
5154

55+
# Test error raising
56+
with assert_raises(ValueError):
57+
bit_lvl_repr(RayleighChannel.channel_gains[0], array((2, 4, 6)))
58+
5259

5360
def test_max_log_approx():
5461
x = array((-1, -1, 1))
@@ -71,8 +78,68 @@ def decode(pt):
7178
err_msg='Wrong LLRs without noise')
7279

7380

74-
def test_kbest():
75-
pass # Tested in test_links
81+
class ModemTestcase:
82+
qam_modems = [QAMModem(4), QAMModem(16), QAMModem(64)]
83+
psk_modems = [PSKModem(4), PSKModem(16), PSKModem(64)]
84+
modems = qam_modems + psk_modems
85+
86+
def __init__(self):
87+
# Create a custom Modem
88+
custom_constellation = [re + im * 1j for re, im in product((-3.5, -0.5, 0.5, 3.5), repeat=2)]
89+
self.custom_modems = [Modem(custom_constellation)]
90+
91+
# Add to custom modems a QAM modem with modified constellation
92+
QAM_custom = QAMModem(16)
93+
QAM_custom.constellation = custom_constellation
94+
self.custom_modems.append(QAM_custom)
95+
self.modems += self.custom_modems
96+
97+
# Assert that error is raised when the contellation length is not a power of 2
98+
with assert_raises(ValueError):
99+
QAM_custom.constellation = (0, 0, 0)
100+
101+
def test(self):
102+
for modem in self.modems:
103+
self.do(modem)
104+
for modem in self.qam_modems:
105+
self.do_qam(modem)
106+
for modem in self.psk_modems:
107+
self.do_psk(modem)
108+
for modem in self.custom_modems:
109+
self.do_custom(modem)
110+
111+
# Default methods for TestClasses that not implement a specific test
112+
def do(self, modem):
113+
pass
114+
115+
def do_qam(self, modem):
116+
pass
117+
118+
def do_psk(self, modem):
119+
pass
120+
121+
def do_custom(self, modem):
122+
pass
123+
124+
125+
class TestModulateHardDemodulate(ModemTestcase):
126+
127+
def do(self, modem):
128+
for bits in product(*((0, 1),) * modem.num_bits_symbol):
129+
assert_array_equal(bits, modem.demodulate(modem.modulate(bits), 'hard'),
130+
err_msg='Bits are not equal after modulation and hard demodulation')
131+
132+
133+
class TestEs(ModemTestcase):
134+
135+
def do_qam(self, modem):
136+
assert_allclose(signal_power(modem.constellation), 2 * (modem.m - 1) / 3)
137+
138+
def do_psk(self, modem):
139+
assert_allclose(modem.Es, 1)
140+
141+
def do_custom(self, modem):
142+
assert_allclose(modem.Es, 12.5)
76143

77144

78145
if __name__ == "__main__":

commpy/tests/test_utilities.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Authors: CommPy contributors
2+
# License: BSD 3-Clause
3+
4+
from numpy import array
5+
from numpy.random import seed
6+
from numpy.testing import run_module_suite, assert_array_equal
7+
8+
from commpy.utilities import dec2bitarray
9+
10+
11+
def test_dec2bitarray():
12+
# Assert result
13+
assert_array_equal(dec2bitarray(17, 8), array((0, 0, 0, 1, 0, 0, 0, 1)))
14+
assert_array_equal(dec2bitarray((17, 12), 5), array((1, 0, 0, 0, 1, 0, 1, 1, 0, 0)))
15+
16+
17+
if __name__ == "__main__":
18+
seed(17121996)
19+
run_module_suite()

commpy/utilities.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,25 @@
1616
upsample -- Upsample by an integral factor (zero insertion).
1717
signal_power -- Compute the power of a discrete time signal.
1818
"""
19-
from __future__ import division # Python 2 compatibility
19+
20+
import itertools as it
2021

2122
import numpy as np
2223

2324
__all__ = ['dec2bitarray', 'bitarray2dec', 'hamming_dist', 'euclid_dist', 'upsample',
2425
'signal_power']
2526

27+
vectorized_binary_repr = np.vectorize(np.binary_repr)
28+
29+
2630
def dec2bitarray(in_number, bit_width):
2731
"""
2832
Converts a positive integer to NumPy array of the specified size containing
2933
bits (0 and 1).
3034
3135
Parameters
3236
----------
33-
in_number : int
37+
in_number : int or array-like of int
3438
Positive integer to be converted to a bit array.
3539
3640
bit_width : int
@@ -39,17 +43,12 @@ def dec2bitarray(in_number, bit_width):
3943
Returns
4044
-------
4145
bitarray : 1D ndarray of ints
42-
Array containing the binary representation of the input decimal.
46+
Array containing the binary representation of all the input decimal(s).
4347
4448
"""
4549

46-
binary_string = bin(in_number)
47-
length = len(binary_string)
48-
bitarray = np.zeros(bit_width, 'int')
49-
for i in range(length - 2):
50-
bitarray[bit_width - i - 1] = int(binary_string[length - i - 1])
51-
52-
return bitarray
50+
binary_words = vectorized_binary_repr(np.array(in_number, ndmin=1), bit_width)
51+
return np.fromiter(it.chain.from_iterable(binary_words), dtype=np.int8)
5352

5453

5554
def bitarray2dec(in_bitarray):

0 commit comments

Comments
 (0)