Turn the SelectableInt less-than/greater-than comparisons into signed versions.
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, floordiv, truediv, mod, or_, and_, xor,
5 neg, inv, lshift, rshift)
6
7
def check_extsign(a, b):
    """Coerce ``b`` into a SelectableInt that can be combined with ``a``.

    - a FieldSelectableInt is first reduced to its selected bits
      (``get_range()``);
    - a plain int is wrapped as a SelectableInt of ``a.bits`` width;
    - a SelectableInt that is exactly 256 bits wide is re-wrapped (masked)
      at ``a.bits`` width -- presumably 256 is the width used for
      sign-extended constants elsewhere; TODO confirm against callers;
    - any other SelectableInt is returned unchanged (callers then check
      that the widths match).
    """
    if isinstance(b, FieldSelectableInt):
        b = b.get_range()
    if isinstance(b, int):
        return SelectableInt(b, a.bits)
    # only the 256-bit case gets narrowed; every other width passes through
    needs_narrowing = b.bits == 256
    return SelectableInt(b.value, a.bits) if needs_narrowing else b
16
17
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target SelectableInt and ``br`` is a BitRange mapping
    field bit positions (0, 1, 2, ...) onto bit positions of ``si``.
    Reads gather the mapped bits (get_range). Arithmetic computes on the
    gathered bits and writes the result back through merge(), which
    updates the shared target ``si`` in place.
    """

    def __init__(self, si, br):
        self.si = si # target selectable int
        # a plain list/tuple of target positions is converted to a BitRange
        # keyed 0, 1, 2, ...
        if isinstance(br, (list, tuple)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br # map of indices.

    def eq(self, b):
        """Assign from ``b``.

        If ``b`` is a SelectableInt, its bits are written one by one into
        this field. Otherwise ``b`` is assumed to be another
        FieldSelectableInt, and copies of its target and bit map are taken.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to the field bits and ``b``, then merge back."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the field bits, then merge back."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        """Read field bits.

        An int key returns the single target bit mapped by ``br``. A slice
        returns the concatenation of the mapped target bits.
        """
        print("getitem", key, self.br)
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            key = self.br[key] # don't do POWER 1.3.4 bit-inversion
            return self.si[key]
        if isinstance(key, slice):
            key = self.br[key]
            return selectconcat(*[self.si[x] for x in key])

    def __setitem__(self, key, value):
        """Write field bits through the ``br`` mapping into the target."""
        if isinstance(key, SelectableInt):
            key = key.value
        key = self.br[key] # don't do POWER 1.3.4 bit-inversion
        if isinstance(key, int):
            return self.si.__setitem__(key, value)
        else:
            # slice: key is now a list of target positions
            if not isinstance(value, SelectableInt):
                value = SelectableInt(value, bits=len(key))
            for i, k in enumerate(key):
                self.si[k] = value[i]

    def __neg__(self):
        return self._op1(neg)

    # Backward-compatible alias. ``__negate__`` is not a python dunder, so
    # unary minus never reached it, and its old body referenced the
    # undefined name ``negate``.
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        # delegates to SelectableInt's ``/`` on the gathered field bits
        return self._op(truediv, b)

    # Python 2 spelling, kept for any explicit callers. Python 3 only
    # dispatches ``/`` to __truediv__.
    __div__ = __truediv__

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Gather the mapped target bits into a new len(br)-bit SelectableInt."""
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Scatter the bits of ``vi`` back into the target, returning a copy.

        copy() is shallow, so the returned object shares ``si`` with
        ``self``. The target is therefore updated in place and the change
        is visible through both objects.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
116
117
class FieldSelectableIntTestCase(unittest.TestCase):
    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        # field bits are a[0], a[2], a[3] = 1, 1, 0 -> 0b110.
        # 0b110 + 0b011 = 0b1001, masked to the 3-bit field width -> 0b001
        c = fs + b
        print(c)
        self.assertEqual(c.get_range(), 0b001)

    def test_select(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs.get_range(), 0b0011)

    def test_select_range(self):
        a = SelectableInt(0b00001111, 8)
        br = BitRange()
        br[0] = 0
        br[1] = 1
        br[2] = 4
        br[3] = 5
        fs = FieldSelectableInt(a, br)

        self.assertEqual(fs[2:4], 0b11)

        fs[0:2] = 0b10
        self.assertEqual(fs.get_range(), 0b1011)
155
class SelectableInt:
    """SelectableInt - a class that behaves (mostly) like python int

    this class is designed to mirror the behaviour of python int.
    the difference is that it must contain the context of the bitwidth
    (number of bits) associated with that integer: the stored value is
    always masked down to ``bits`` bits.

    deviations from plain python int worth knowing about:

    * ``<``, ``<=``, ``>``, ``>=`` are *signed*: both sides are read as
      two's-complement numbers of this width (see to_signed_int).
      selectltu / selectgtu provide unsigned comparisons.
    * ``==`` compares the raw (unsigned) masked values.
    * ``*`` widens: the product is ``self.bits + b.bits`` bits wide.
    * indexing uses MSB-0 numbering: index 0 is the most significant bit.

    FieldSelectableInt can then operate on partial bits, and because there
    is a bit width associated with SelectableInt, slices operate correctly.
    slice bounds must satisfy ``0 <= start < stop <= bits``; negative
    indices and steps other than 1 are rejected.
    """

    def __init__(self, value, bits):
        if isinstance(value, SelectableInt):
            value = value.value
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Copy value and width from another SelectableInt, in place."""
        self.value = b.value
        self.bits = b.bits

    def to_signed_int(self):
        """Return the value as a python int, read as ``bits``-wide two's complement."""
        print ("to signed?", self.value & (1<<(self.bits-1)), self.value)
        if self.value & (1<<(self.bits-1)) != 0: # negative
            res = self.value - (1<<self.bits)
            print (" val -ve:", self.bits, res)
        else:
            res = self.value
            print (" val +ve:", res)
        return res

    def _op(self, op, b):
        """Apply ``op`` to the raw values; the result keeps self's width."""
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(op(self.value, b.value), self.bits)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        # different case: mul result needs to fit the total bitsize
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        elif isinstance(b, FieldSelectableInt):
            # use the field's selected bits, as every other operator does
            # (via check_extsign); previously this raised AttributeError
            b = b.get_range()
        print("SelectableInt mul", hex(self.value), hex(b.value),
              self.bits, b.bits)
        return SelectableInt(self.value * b.value, self.bits + b.bits)

    def __floordiv__(self, b):
        return self._op(floordiv, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def __abs__(self):
        # negate only when the sign (top) bit is set
        print("abs", self.value & (1 << (self.bits-1)))
        if self.value & (1 << (self.bits-1)) != 0:
            return -self
        return self

    def __rsub__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value - self.value, self.bits)

    def __radd__(self, b):
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(b.value + self.value, self.bits)

    def __rxor__(self, b):
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's complement negation, masked to the width
        return SelectableInt(~self.value + 1, self.bits)

    def __lshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value << b.value, self.bits)

    def __rshift__(self, b):
        b = check_extsign(self, b)
        return SelectableInt(self.value >> b.value, self.bits)

    def __getitem__(self, key):
        """Read one bit (int key) or a bit range (slice) using MSB-0 numbering."""
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB-0 [start, stop) into LSB-0 shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        """Write one bit (int key) or a bit range (slice) using MSB-0 numbering."""
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    def _signed_rhs(self, other):
        """Convert the right-hand side of an ordering comparison to a signed int.

        A FieldSelectableInt is reduced to its selected bits. A
        SelectableInt is width-matched through check_extsign and then read
        as two's complement. A plain int is used as-is. Anything else
        raises TypeError. The previous ``assert False`` was stripped under
        ``python -O``, which made the comparison silently return None.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.to_signed_int()
        if not isinstance(other, int):
            raise TypeError("cannot compare SelectableInt with %s"
                            % type(other).__name__)
        return other

    def __ge__(self, other):
        other = self._signed_rhs(other)
        return onebit(self.to_signed_int() >= other)

    def __le__(self, other):
        other = self._signed_rhs(other)
        return onebit(self.to_signed_int() <= other)

    def __gt__(self, other):
        other = self._signed_rhs(other)
        return onebit(self.to_signed_int() > other)

    def __lt__(self, other):
        print ("SelectableInt lt", self, other)
        other = self._signed_rhs(other)
        a = self.to_signed_int()
        res = onebit(a < other)
        print (" a < b", a, other, res)
        return res

    def __eq__(self, other):
        """Unsigned equality of the masked values; returns a 1-bit SelectableInt."""
        print("__eq__", self, other)
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
            print (" eq", other, self.value, other == self.value)
        if not isinstance(other, int):
            raise TypeError("cannot compare SelectableInt with %s"
                            % type(other).__name__)
        return onebit(other == self.value)

    def narrow(self, bits):
        """Return a copy truncated (masked) to ``bits`` low-order bits."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                              self.bits)

    def __len__(self):
        return self.bits

    def asint(self):
        return self.value
403
404
def onebit(bit):
    """Wrap the truth value of ``bit`` as a 1-bit SelectableInt (1 or 0)."""
    return SelectableInt(int(bool(bit)), 1)
407
408
def _unsigned_value(x):
    """Return the raw unsigned value of a SelectableInt-like operand.

    A FieldSelectableInt is reduced to its selected bits first. A
    SelectableInt yields its masked ``.value``. Anything else (e.g. a
    plain int) is returned unchanged.
    """
    if isinstance(x, FieldSelectableInt):
        x = x.get_range()
    if isinstance(x, SelectableInt):
        x = x.value
    return x


def selectltu(lhs, rhs):
    """ less-than (unsigned)

    Compares the raw masked values, ignoring sign. Either side may be a
    SelectableInt, a FieldSelectableInt or an int. Returns a 1-bit
    SelectableInt.
    """
    return onebit(_unsigned_value(lhs) < _unsigned_value(rhs))


def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    Compares the raw masked values, ignoring sign. Either side may be a
    SelectableInt, a FieldSelectableInt or an int. Returns a 1-bit
    SelectableInt.
    """
    return onebit(_unsigned_value(lhs) > _unsigned_value(rhs))
423
424
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the position(s) given by ``idx``.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)``. In the tuple form, each
    ``lhs[lower + k*step]`` is set from ``rhs[k*step]``. Note that the
    source is read with the same stride as the destination. A missing or
    None step means 1.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        # range() rejects step=None, so the 2-tuple form used to raise
        # TypeError unconditionally
        if step is None:
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step) # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
440
441
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, leftmost argument in the most-significant bits.

    args: SelectableInt or FieldSelectableInt operands. A FieldSelectableInt
    contributes only its selected bits (get_range()). As a convenience, a
    single int argument combined with ``repeat != 1`` is treated as a
    1-bit value.

    repeat: the whole argument list is repeated this many times before
    concatenation (e.g. ``selectconcat(x, repeat=8)`` replicates x 8 times).

    For SelectableInt inputs, the result is a copy whose width is the sum
    of the operand widths. The caller's first argument is not modified.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1: # multiplies the incoming arguments
        args = list(args) * repeat
    # A field contributes only its selected bits, matching how every other
    # operator in this module treats FieldSelectableInt. Previously later
    # arguments used the whole target (.si), and a field passed as the
    # first argument was never unwrapped at all.
    args = [a.get_range() if isinstance(a, FieldSelectableInt) else a
            for a in args]
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    print("concat", repeat, res)
    return res
459
460
class SelectableIntTestCase(unittest.TestCase):
    def test_arith(self):
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = abs(f)
        h = abs(a)
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        # multiply widens, so the full product is kept (no 8-bit truncation)
        self.assertEqual(e.value, a.value * b.value)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits + b.bits)
        self.assertEqual(a.bits, f.bits)
        self.assertEqual(a.bits, g.bits)
        self.assertEqual(a.bits, h.bits)

    def test_logic(self):
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)

    def test_concat(self):
        a = SelectableInt(0x1, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0xff)
        self.assertEqual(c.bits, 8)
        a = SelectableInt(0x0, 1)
        c = selectconcat(a, repeat=8)
        self.assertEqual(c, 0x00)
        self.assertEqual(c.bits, 8)

    def test_repr(self):
        for i in range(65536):
            a = SelectableInt(i, 16)
            b = eval(repr(a))
            self.assertEqual(a, b)

    def test_cmp(self):
        a = SelectableInt(10, bits=8)
        b = SelectableInt(5, bits=8)
        self.assertTrue(a > b)
        self.assertFalse(a < b)
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_unsigned(self):
        # <, >, <=, >= on SelectableInt are signed; selectgtu/selectltu
        # are the unsigned comparisons
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        self.assertTrue(selectgtu(a, b))
        self.assertFalse(selectltu(a, b))
        self.assertTrue(a != b)
        self.assertFalse(a == b)

    def test_signed(self):
        # in 8-bit two's complement 0x80 is -128 and 0x7f is +127
        a = SelectableInt(0x80, bits=8)
        b = SelectableInt(0x7f, bits=8)
        self.assertTrue(a < b)
        self.assertFalse(a > b)
        self.assertTrue(a <= b)
        self.assertFalse(a >= b)
        # 0xff is -1, so it is less than a plain int 0
        self.assertTrue(SelectableInt(0xff, bits=8) < 0)
542
543
if __name__ == "__main__":
    # running this module directly executes its unittest TestCases
    unittest.main()