SelectableInt: make __mul__ return enough space to fit the result
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Coerce ``b`` into a SelectableInt usable alongside ``a``.

    * a FieldSelectableInt is first collapsed to the bits it selects;
    * a plain int is wrapped (masked) at ``a``'s width;
    * a SelectableInt whose width is exactly 256 is re-wrapped at ``a``'s
      width -- 256 appears to be used as an "unsized" marker (TODO confirm
      with the code that creates 256-bit values);
    * any other SelectableInt is returned unchanged.
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt and ``br`` maps field-bit index to
    target-bit index.  Reads gather the mapped bits out of ``si``; writes
    (``__setitem__``, ``eq`` and the operators, via ``merge``) store the
    bits back into ``si``.
    """
    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            # plain sequence of target indices: field bit i -> br[i]
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign from ``b``.

        A SelectableInt is written bit-by-bit into the selected target bits;
        anything else is treated as another FieldSelectableInt whose target
        and bit map are shallow-copied onto this one.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the selected bits and merge the result back."""
        vi = self.get_range()
        vi = op(vi, b)
        # SelectableInt.__mul__ returns a result wider than its operands.
        # merge() indexes MSB-first (index 0 = top bit), so an over-wide
        # result would write back its *upper* bits; keep the low bits
        # instead so the field behaves as fixed width.
        width = len(self.br)
        if vi.bits > width:
            vi = vi.narrow(width)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the selected bits and merge the result back."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        print ("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key] # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key] # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # Python never calls __negate__ itself; kept so explicit callers work
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 2 spelling of "/", kept for existing callers
    __div__ = __truediv__

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Return the selected bits as a new len(br)-bit SelectableInt."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Write ``vi``'s bits into the target and return a copy of self.

        copy() is shallow, so the returned object shares ``si`` with
        ``self``: the write is visible through both.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
106
107
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for FieldSelectableInt selecting bits out of a SelectableInt."""

    @staticmethod
    def _make_br(*targets):
        """Build a BitRange mapping field bit i onto target bit targets[i]."""
        mapping = BitRange()
        for field_bit, target_bit in enumerate(targets):
            mapping[field_bit] = target_bit
        return mapping

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        fs = FieldSelectableInt(a, self._make_br(0, 2, 3))
        c = fs + b
        print (c)
        # the result is only printed for now; a value check is still
        # to be written (intended: compare against a.value + b.value)

    def test_select(self):
        src = SelectableInt(0b00001111, 8)
        fs = FieldSelectableInt(src, self._make_br(0, 1, 4, 5))
        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        src = SelectableInt(0b00001111, 8)
        fs = FieldSelectableInt(src, self._make_br(0, 1, 4, 5))

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
145
146
class SelectableInt:
    """SelectableInt - a class that behaves exactly like python int

    this class is designed to mirror precisely the behaviour of python int.
    the only difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices are interpreted
    against that width.  Bit numbering is MSB-first (index 0 is the top
    bit).  Slice start points must be non-negative (asserted in
    __getitem__/__setitem__).

    The value is stored unsigned, masked to ``bits`` bits.  Comparison
    operators return 1-bit SelectableInts rather than bools.
    """
    def __init__(self, value, bits):
        # accept another SelectableInt (unwrap it) or a plain int
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Copy value and width from ``b`` in place."""
        self.value = b.value
        self.bits = b.bits

    def _op(self, op, b):
        """Apply ``op`` to the raw values; the result keeps self.bits width."""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        print ("SelectableInt mul", hex(self.value), hex(b.value),
               self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __rmul__(self, b):
        # int * SelectableInt: multiplication commutes, so reuse __mul__
        return self.__mul__(b)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        """Two's-complement absolute value at the same width.

        Values with the MSB set are negated; others are returned as a new
        copy.  The most negative value (only the MSB set) wraps to itself,
        as fixed-width hardware does.
        """
        if self.bits > 0 and (self.value >> (self.bits - 1)) & 1:
            return -self
        return SelectableInt(self.value, self.bits)

    def __rsub__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def __getitem__(self, key):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB-first [start, stop) into LSB-based shift amounts
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def __ge__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            # ``other`` is a plain int by now: compare it directly
            return onebit(self.value >= other)
        assert False

    def __le__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            return onebit(self.value <= other)
        assert False

    def __gt__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            return onebit(self.value > other)
        assert False

    def __lt__(self, other):
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            return onebit(self.value < other)
        assert False

    def __eq__(self, other):
        print ("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if isinstance(other, int):
            return onebit(other == self.value)
        assert False

    def narrow(self, bits):
        """Return a new SelectableInt keeping only the low ``bits`` bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
366
367
def onebit(bit):
    """Return the truthiness of ``bit`` as a 1-bit SelectableInt (1 or 0)."""
    return SelectableInt(int(bool(bit)), 1)
370
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares ``lhs.value`` with ``rhs`` (using its raw ``.value`` when rhs
    is a SelectableInt) and returns the outcome as a 1-bit SelectableInt.
    """
    rhs_raw = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_raw)
377
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares ``lhs.value`` with ``rhs`` (using its raw ``.value`` when rhs
    is a SelectableInt) and returns the outcome as a 1-bit SelectableInt.
    """
    rhs_raw = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_raw)
384
385
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy element(s) of ``rhs`` into ``lhs`` at the position(s) in ``idx``.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``.  For the tuple form,
    ``lhs[lower + k] = rhs[k]`` for each k in ``range(0, upper - lower,
    step)``: source offsets advance by the same step as the destination.
    A missing or None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects None as a step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
401
402
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, first argument in the most-significant bits.

    Each argument must be a SelectableInt or a FieldSelectableInt; the
    latter contributes only the bits it selects (its ``get_range()``), not
    its whole target register.  With ``repeat`` != 1 the argument list is
    repeated that many times; a lone int argument is then treated as a
    1-bit value (e.g. ``selectconcat(1, repeat=8)`` gives 0xff).

    Returns a new SelectableInt whose width is the sum of the parts.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1: # multiplies the incoming arguments
        args = list(args) * repeat
    parts = [a.get_range() if isinstance(a, FieldSelectableInt) else a
             for a in args]
    # copy so that widening the accumulator never mutates a caller's value
    res = copy(parts[0])
    for i in parts[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    print ("concat", repeat, res)
    return res
420
421
class SelectableIntTestCase(unittest.TestCase):
    """Tests for SelectableInt arithmetic, logic, indexing and comparison."""

    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # __mul__ widens its result to a.bits + b.bits so the whole
        # product fits: no truncation to 8 bits
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_mul_full_width(self):
        # a product that overflows the operand width must not be truncated
        a = SelectableInt(0xFF, 8)
        b = SelectableInt(0xFF, 8)
        e = a * b
        self.assertEqual(e.value, 0xFE01)
        self.assertEqual(e.bits, 16)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)
if __name__ == "__main__":
    # executing this module directly runs all of its unittest TestCases
    unittest.main()