whoops off-by-one in slice ranges
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Normalise operand *b* so it can be combined with SelectableInt *a*.

    - A plain ``int`` is wrapped as a SelectableInt of ``a.bits`` width, so
      every binary operator accepts ints (not only ``+`` and ``-``, which
      already converted ints themselves before calling this).
    - An operand whose ``bits`` is exactly 256 is re-created at ``a.bits``
      width (its value is truncated to that width).  NOTE(review): 256
      looks like a placeholder width for values of unknown size — confirm.
    - Anything else is returned unchanged; callers assert the widths match.
    """
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits != 256:
        return b
    return SelectableInt(b.value, a.bits)
11
12
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target (normally a SelectableInt) and ``br`` maps field
    bit positions to bit positions inside ``si``.  A list or tuple ``br`` is
    converted into a BitRange where key ``i`` maps to ``br[i]``.

    Operators extract the field into a fresh SelectableInt (``get_range``),
    apply the operation to it, then write the result back (``merge``).
    """
    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def _op(self, op, b):
        """Apply binary *op* to the extracted field and merge the result."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary *op* to the extracted field and merge the result."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si[key]

    def __setitem__(self, key, value):
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        self.si[key] = value

    def __neg__(self):
        return self._op1(neg)

    # Original (non-dunder) name kept for any caller invoking it explicitly.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 3 routes "/" to __truediv__; __div__ is the Python 2 name.
    __div__ = __truediv__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Return the field as a new SelectableInt of ``len(self.br)`` bits.

        Field bit ``k`` is taken from target bit ``self.br[k]``.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write field value *vi* back into the target; return a copy of self.

        NOTE(review): ``copy`` is shallow, so the returned object shares
        ``self.si`` and the original target is modified in place — confirm
        this is the intended semantics before relying on it.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
81
82
class FieldSelectableIntTestCase(unittest.TestCase):
    """Checks arithmetic through a FieldSelectableInt bit-range view."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # MSB0 bits 0, 2, 3 of 0b10101 are 1, 1, 0 -> field value 0b110
        field = fs.get_range()
        self.assertEqual(field.value, 0b110)
        self.assertEqual(field.bits, 3)
        c = fs + b
        # the addition wraps at the field's own 3-bit width
        result = c.get_range()
        self.assertEqual(result.value, (0b110 + 0b011) & 0b111)
        self.assertEqual(result.bits, 3)
95
96
class SelectableInt:
    """Fixed-width unsigned integer with POWER (MSB0) bit selection.

    ``value`` is stored masked to ``bits`` bits, so results wrap modulo
    ``2 ** bits``.  Bit numbering follows POWER ISA 3.0B (p4, 1.3.2): index 0
    is the MOST significant bit.  ``si[start:stop]`` selects MSB0 bits
    ``start`` .. ``stop - 1``; an omitted bound means that end of the value.

    Binary operators accept a SelectableInt of the same width or a plain
    ``int`` (wrapped at this width).  Comparisons compare the unsigned values
    and return a 1-bit SelectableInt built by ``onebit``.
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def _coerce(self, b):
        """Return operand *b* as a SelectableInt of this width.

        Plain ints are wrapped at ``self.bits`` so every operator accepts
        them; other operands go through ``check_extsign``.  Raises
        AssertionError if the widths still differ.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return b

    def __add__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        """Unsigned integer division; raises ZeroDivisionError on zero."""
        b = self._coerce(b)
        # integer division: "/" would produce a float, which cannot be masked
        return SelectableInt(self.value // b.value, self.bits)

    # Python 3 routes "/" to __truediv__ (__div__ is the Python 2 name, kept
    # for callers that use it explicitly).  Both divide as integers.
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __mod__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._coerce(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        # ~value is negative in Python; the constructor masks it to width
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's complement negation, masked to width by the constructor
        return SelectableInt(~self.value + 1, self.bits)

    def _lsb0_slice(self, key):
        """Translate an MSB0 slice into ``(shift, width)`` in LSB0 terms."""
        assert key.step is None or key.step == 1
        start = 0 if key.start is None else key.start
        stop = self.bits if key.stop is None else key.stop
        assert start < stop
        assert start >= 0
        assert stop <= self.bits
        # MSB0 bits [start, stop) occupy LSB0 bits [bits - stop, bits - start)
        return self.bits - stop, stop - start

    def __getitem__(self, key):
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            shift = self.bits - (key + 1)
            return SelectableInt((self.value >> shift) & 1, 1)
        if isinstance(key, slice):
            shift, width = self._lsb0_slice(key)
            mask = (1 << width) - 1
            return SelectableInt((self.value >> shift) & mask, width)
        raise TypeError("SelectableInt indices must be int or slice, not %s"
                        % type(key).__name__)

    def __setitem__(self, key, value):
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            shift = self.bits - (key + 1)  # MSB0 index -> LSB0 bit position
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            mask = 1 << shift
        elif isinstance(key, slice):
            shift, width = self._lsb0_slice(key)
            if isinstance(value, SelectableInt):
                assert value.bits == width, "%d into %d" % (value.bits, width)
                value = value.value
            mask = ((1 << width) - 1) << shift
        else:
            raise TypeError("SelectableInt indices must be int or slice, "
                            "not %s" % type(key).__name__)
        self.value = (self.value & ~mask) | ((value << shift) & mask)

    def _cmp_value(self, other):
        """Return *other* as an int to compare with, or unchanged if it is
        neither an int nor a SelectableInt (asserts matching widths)."""
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        return other

    # Each comparison puts self.value on the LEFT so that e.g. ``a < b``
    # really tests a < b.  Unsupported operand types return NotImplemented,
    # letting Python try the reflected operation or raise TypeError.
    def __ge__(self, other):
        other = self._cmp_value(other)
        if isinstance(other, int):
            return onebit(self.value >= other)
        return NotImplemented

    def __le__(self, other):
        other = self._cmp_value(other)
        if isinstance(other, int):
            return onebit(self.value <= other)
        return NotImplemented

    def __gt__(self, other):
        other = self._cmp_value(other)
        if isinstance(other, int):
            return onebit(self.value > other)
        return NotImplemented

    def __lt__(self, other):
        other = self._cmp_value(other)
        if isinstance(other, int):
            return onebit(self.value < other)
        return NotImplemented

    def __eq__(self, other):
        other = self._cmp_value(other)
        if isinstance(other, int):
            return onebit(self.value == other)
        return NotImplemented

    def narrow(self, bits):
        """Return a copy keeping only the low *bits* bits (bits <= self.bits)."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)
264
def onebit(bit):
    """Wrap the truthiness of *bit* as a 1-bit SelectableInt (1 or 0)."""
    return SelectableInt(int(bool(bit)), 1)
267
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    *lhs* is a SelectableInt; *rhs* may be a SelectableInt or a plain int.
    Returns a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
274
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    *lhs* is a SelectableInt; *rhs* may be a SelectableInt or a plain int.
    Returns a 1-bit SelectableInt.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
281
282
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy entries of *rhs* into positions of *lhs*.

    *idx* is either a single index (performs ``lhs[idx] = rhs[0]``) or a
    tuple ``(lower, upper)`` / ``(lower, upper, step)``.  Destination
    positions are ``range(lower, upper, step)`` and source positions are
    ``range(0, upper - lower, step)``; a missing or ``None`` step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() rejects a None step with TypeError, so default to 1
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
298
299
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, the first argument being most significant.

    With ``repeat != 1`` the argument list is repeated that many times; a
    single plain ``int`` argument is then treated as a 1-bit value (so
    ``selectconcat(1, repeat=8)`` gives 0xff in 8 bits).  Every argument
    after the first must be a SelectableInt.  The first argument is copied,
    so no input is modified.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    return res
315
316
class SelectableIntTestCase(unittest.TestCase):
    """Checks SelectableInt arithmetic, logic, MSB0 bit access, concat, repr."""

    def test_arith(self):
        lhs = SelectableInt(5, 8)
        rhs = SelectableInt(9, 8)
        total, diff, prod, negated = lhs + rhs, lhs - rhs, lhs * rhs, -lhs
        self.assertEqual(total.value, lhs.value + rhs.value)
        # results wrap modulo 2**8
        self.assertEqual(diff.value, (lhs.value - rhs.value) & 0xFF)
        self.assertEqual(prod.value, (lhs.value * rhs.value) & 0xFF)
        self.assertEqual(negated.value, (-lhs.value) & 0xFF)
        for result in (total, diff, prod):
            self.assertEqual(result.bits, lhs.bits)

    def test_logic(self):
        lhs = SelectableInt(0x0F, 8)
        rhs = SelectableInt(0xA5, 8)
        anded, ored, xored, inverted = lhs & rhs, lhs | rhs, lhs ^ rhs, ~lhs
        self.assertEqual(anded.value, lhs.value & rhs.value)
        self.assertEqual(ored.value, lhs.value | rhs.value)
        self.assertEqual(xored.value, lhs.value ^ rhs.value)
        self.assertEqual(inverted.value, 0xF0)

    def test_get(self):
        word = SelectableInt(0xa2, 8)
        # MSB0 numbering: index 0 is the most significant bit
        self.assertEqual(word[7], 0)
        self.assertEqual(word[0:4], 10)
        self.assertEqual(word[4:8], 2)

    def test_set(self):
        word = SelectableInt(0x5, 8)
        word[7] = SelectableInt(0, 1)
        self.assertEqual(word, 4)
        word[4:8] = 9
        self.assertEqual(word, 9)
        word[0:4] = 3
        self.assertEqual(word, 0x39)
        word[0:4] = word[4:8]
        self.assertEqual(word, 0x99)

    def test_concat(self):
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            joined = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(joined, expected)
            self.assertEqual(joined.bits, 8)

    def test_repr(self):
        for n in range(1 << 16):
            original = SelectableInt(n, 16)
            self.assertEqual(original, eval(repr(original)))
378
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in the module.
    unittest.main(module="__main__")