add __bool__ override to selectable_int
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3
4
class SelectableInt:
    """Fixed-width integer with big-endian (MSB0) bit indexing.

    ``value`` is always masked to ``bits`` bits. Integer indexing uses MSB0
    numbering: ``x[0]`` is the most-significant bit and ``x[bits - 1]`` the
    least-significant. Slices use Python half-open semantics on those MSB0
    indices, so ``x[a:b]`` selects exactly ``b - a`` bits.

    Binary arithmetic/logic operators require both operands to have the same
    width and return a new value of that width (results wrap). Comparison
    operators return a 1-bit SelectableInt (see ``onebit``) rather than a
    ``bool``; truthiness is provided by ``__bool__``.
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def __add__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value * b.value, self.bits)

    def __or__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's complement negation; the constructor masks to ``bits``
        return SelectableInt(~self.value + 1, self.bits)

    def _slice_bounds(self, key):
        """Translate an MSB0 half-open slice into ``(shift, width)``.

        ``shift`` is the LSB0 position of the lowest selected bit in
        ``self.value`` and ``width`` the number of selected bits. Omitted
        bounds default to the whole value; only unit steps are supported.
        """
        assert key.step is None or key.step == 1
        start = 0 if key.start is None else key.start
        stop = self.bits if key.stop is None else key.stop
        assert start < stop
        assert start >= 0
        assert stop <= self.bits
        # MSB0 index i lives at LSB0 position bits-1-i, so the half-open
        # MSB0 range [start, stop) is the LSB0 range [bits-stop, bits-start).
        return self.bits - stop, stop - start

    def __getitem__(self, key):
        """Return one bit (int key) or a bit range (slice) as a SelectableInt.

        Raises TypeError for any other key type.
        """
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        if isinstance(key, slice):
            shift, width = self._slice_bounds(key)
            mask = (1 << width) - 1
            value = (self.value >> shift) & mask
            return SelectableInt(value, width)
        raise TypeError("unsupported index type: %r" % type(key))

    def __setitem__(self, key, value):
        """Overwrite one bit (int key) or a bit range (slice) in place.

        *value* may be a plain int (only its low bits are used) or a
        SelectableInt, whose width must match the destination width.
        Raises TypeError for any other key type.
        """
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            shift, width = self._slice_bounds(key)
            if isinstance(value, SelectableInt):
                assert value.bits == width, "%d into %d" % (value.bits, width)
                value = value.value
            # the mask never extends past the top bit, so ``self.value``
            # stays within ``self.bits`` bits
            mask = ((1 << width) - 1) << shift
            value = value << shift
            self.value = (self.value & ~mask) | (value & mask)
        else:
            raise TypeError("unsupported index type: %r" % type(key))

    def _comparison_operand(self, other):
        """Return *other* as a plain int, or None if it cannot be compared.

        A SelectableInt operand must have the same width as ``self``.
        """
        if isinstance(other, SelectableInt):
            assert other.bits == self.bits, \
                "width mismatch: %d vs %d" % (other.bits, self.bits)
            return other.value
        if isinstance(other, int):
            return other
        return None

    def __ge__(self, other):
        other = self._comparison_operand(other)
        if other is None:
            return NotImplemented
        return onebit(self.value >= other)

    def __le__(self, other):
        other = self._comparison_operand(other)
        if other is None:
            return NotImplemented
        return onebit(self.value <= other)

    def __gt__(self, other):
        other = self._comparison_operand(other)
        if other is None:
            return NotImplemented
        return onebit(self.value > other)

    def __lt__(self, other):
        other = self._comparison_operand(other)
        if other is None:
            return NotImplemented
        return onebit(self.value < other)

    def __eq__(self, other):
        other = self._comparison_operand(other)
        if other is None:
            # lets Python fall back to its default (identity) comparison
            return NotImplemented
        return onebit(self.value == other)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value={:x}, bits={})".format(self.value,
                                                          self.bits)


def onebit(bit):
    """Return a 1-bit SelectableInt: 1 if *bit* is truthy, else 0."""
    return SelectableInt(1 if bit else 0, 1)


def selectltu(lhs, rhs):
    """ less-than (unsigned)

    *rhs* may be a SelectableInt or a plain int; the result is a 1-bit
    SelectableInt.
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value < rhs)


def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    *rhs* may be a SelectableInt or a plain int; the result is a 1-bit
    SelectableInt.
    """
    if isinstance(rhs, SelectableInt):
        rhs = rhs.value
    return onebit(lhs.value > rhs)


# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy bits from *rhs* into *lhs* at MSB0 position(s) *idx*.

    *idx* is a single index, or a tuple ``(lower, upper)`` or
    ``(lower, upper, step)`` describing the half-open destination range.
    Destination index ``lower + k`` receives ``rhs[k]``, using the same
    stride on both sides.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects a None step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]


def selectconcat(*args):
    """Concatenate SelectableInts; the first argument is most significant.

    Returns a new object (the first argument is copied, not mutated) whose
    width is the sum of all argument widths.
    """
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    return res


class SelectableIntTestCase(unittest.TestCase):
    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, (a.value * b.value) & 0xFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)
        # slices are half-open: [0:4] is exactly four bits wide
        self.assertEqual(a[0:4].bits, 4)
        self.assertEqual(a[4:8].bits, 4)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)
        # writes must never grow the value past its declared width
        self.assertEqual(a.bits, 8)

    def test_compare(self):
        a = SelectableInt(3, 8)
        b = SelectableInt(5, 8)
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a <= b)
        self.assertFalse(a >= b)
        self.assertTrue(a >= a)
        self.assertTrue(5 > a)
        self.assertTrue(a == 3)
        self.assertFalse(a == b)
228
if "__main__" == __name__:
    # Executed directly (not imported): run the SelectableInt self-tests.
    unittest.main()