try making CR bitrange 32..63 not 0..31
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Return operand ``b`` in a form usable alongside ``a``.

    A FieldSelectableInt is first replaced by the value extracted from
    its field (``get_range()``).  An operand that is exactly 256 bits
    wide is then re-created with the bit width of ``a`` (its value is
    truncated to fit); any other operand is returned untouched.

    NOTE(review): 256 looks like a marker width for constants whose
    final size is not yet known -- confirm against the callers.
    """
    operand = b.get_range() if isinstance(b, FieldSelectableInt) else b
    if operand.bits == 256:
        return SelectableInt(operand.value, a.bits)
    return operand
13
14
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target

    ``si`` is the target being viewed and ``br`` maps each bit index of
    the field onto a bit index of the target.  Item access is redirected
    through that map.  Arithmetic and logic operators work on the value
    extracted by :meth:`get_range` and write the result back through
    :meth:`merge`.
    """
    def __init__(self, si, br):
        """Create a field view.

        si: the target selectable int.
        br: either a mapping of field index to target index, or a list
            or tuple giving the target index of field bits 0, 1, 2, ...
            (converted to a BitRange).
        """
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Assign from ``b``.

        A SelectableInt is copied bit by bit into the field (bits
        0..b.bits-1, each routed through the index map).  Anything else
        is expected to provide ``si`` and ``br`` attributes, of which
        (shallow) copies are taken over.
        """
        if isinstance(b, SelectableInt):
            for i in range(b.bits):
                self[i] = b[i]
        else:
            self.si = copy(b.si)
            self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary operator ``op`` to the field value and ``b``."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary operator ``op`` to the field value."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        print("getitem", key, self.br)
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si[key]

    def __setitem__(self, key, value):
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si.__setitem__(key, value)

    def __neg__(self):
        # unary minus is dispatched to __neg__, and the function imported
        # from the operator module is called "neg"
        return self._op1(neg)
    # previous (never dispatched) spelling, kept for backward compatibility
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        # Python 3 dispatches "/" to __truediv__; the division itself is
        # whatever "/" means for the value returned by get_range()
        return self._op(truediv, b)
    # Python 2 spelling, kept for backward compatibility
    __div__ = __truediv__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Extract the field as a new SelectableInt of len(br) bits."""
        print("get_range", self.si)
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            print("get_range", k, v, self.si[v])
            vi[k] = self.si[v]
        print("get_range", vi)
        return vi

    def merge(self, vi):
        """Write the bits of ``vi`` into the field and return a view of it.

        NOTE(review): copy() is shallow, so the returned object shares
        its target with this one and the target itself is modified.
        That also means the operators above change their left-hand
        operand's target -- confirm this is intended.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
92
93
class FieldSelectableIntTestCase(unittest.TestCase):
    """Smoke test for arithmetic on a FieldSelectableInt."""

    def test_arith(self):
        """Add a 3-bit value to a 3-bit field of a 5-bit target."""
        target = SelectableInt(0b10101, 5)
        addend = SelectableInt(0b011, 3)
        # field bits 0, 1, 2 select target bits 0, 2, 3
        field_map = BitRange()
        for field_bit, target_bit in enumerate((0, 2, 3)):
            field_map[field_bit] = target_bit
        field = FieldSelectableInt(target, field_map)
        result = field + addend
        print(result)
        # no assertion yet: the expected value still has to be worked out
        #self.assertEqual(result.value, target.value + addend.value)
106
107
class SelectableInt:
    """Fixed-width integer with individually selectable bits.

    ``value`` always holds the number truncated to ``bits`` bits (so it
    is never negative).  Bit and slice indices follow the POWER
    numbering convention: index 0 is the most significant bit.

    Binary operators accept another SelectableInt of the same width, a
    FieldSelectableInt (its extracted value is used) or a plain int
    (truncated to this object's width).  Results are truncated to this
    object's width.  Comparisons return a one-bit SelectableInt and work
    on the stored values, i.e. they are unsigned.
    """
    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Take over both value and width from ``b``."""
        self.value = b.value
        self.bits = b.bits

    def _operand(self, b):
        """Turn ``b`` into a SelectableInt of the same width as self.

        Plain ints are truncated to self.bits; everything else goes
        through check_extsign and must then match in width.
        """
        if isinstance(b, int):
            return SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits
        return b

    def _compare_value(self, other):
        """Return the plain int that ``other`` stands for in comparisons.

        Raises AssertionError for operand types that are not supported.
        """
        if isinstance(other, int):
            return other
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            return other.value
        raise AssertionError("cannot compare SelectableInt with %r"
                             % (other,))

    @staticmethod
    def _bit(flag):
        """Wrap a truth value in a one-bit SelectableInt."""
        return SelectableInt(1 if flag else 0, 1)

    def __add__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        # integer division: "/" on two ints would produce a float, which
        # cannot be masked down to a bit width
        b = self._operand(b)
        return SelectableInt(self.value // b.value, self.bits)
    # Python 3 dispatches "/" to __truediv__; __div__ is the Python 2
    # name and is kept for backward compatibility
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __mod__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        print("__and__", self, b)
        b = self._operand(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._operand(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's complement negation, truncated to self.bits
        return SelectableInt(~self.value + 1, self.bits)

    def __getitem__(self, key):
        """Read one bit (int key) or a run of bits (slice key).

        Returns a new SelectableInt of 1 bit, or of as many bits as the
        slice covers.  Slices must run forwards with a step of 1.
        """
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order!  see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB0 slice bounds into shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            #print ("__getitem__ slice num bits", bits)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        """Overwrite one bit (int key) or a run of bits (slice key).

        ``value`` may be a plain int or a SelectableInt whose width
        matches the number of bits being written.
        """
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            assert key.step is None or key.step == 1
            assert key.start < key.stop
            assert key.start >= 0
            assert key.stop <= self.bits

            # convert MSB0 slice bounds into shift positions
            stop = self.bits - key.start
            start = self.bits - key.stop

            bits = stop - start
            #print ("__setitem__ slice num bits", bits)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    # In all comparisons self is the left-hand operand.
    # NOTE(review): the stored values are compared directly, which makes
    # these unsigned comparisons -- confirm that is what callers expect.

    def __ge__(self, other):
        return self._bit(self.value >= self._compare_value(other))

    def __le__(self, other):
        return self._bit(self.value <= self._compare_value(other))

    def __gt__(self, other):
        return self._bit(self.value > self._compare_value(other))

    def __lt__(self, other):
        return self._bit(self.value < self._compare_value(other))

    def __eq__(self, other):
        print("__eq__", self, other)
        return self._bit(self.value == self._compare_value(other))

    def narrow(self, bits):
        """Return a copy truncated to ``bits`` (must not exceed self.bits)."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)
290
def onebit(bit):
    """Return a one-bit SelectableInt: 1 when ``bit`` is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
293
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    ``lhs`` must provide a ``value`` attribute; ``rhs`` may be either a
    SelectableInt or a plain number.  The result is a one-bit value.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
300
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    ``lhs`` must provide a ``value`` attribute; ``rhs`` may be either a
    SelectableInt or a plain number.  The result is a one-bit value.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
307
308
309 # XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy items of ``rhs`` into ``lhs`` at the positions given by ``idx``.

    idx: either a single index, in which case ``lhs[idx] = rhs[0]``, or a
         tuple ``(lower, upper)`` / ``(lower, upper, step)`` describing
         the half-open range of destination indices.  A missing (or
         None) step means 1.

    For a range, destination indices run over
    ``range(lower, upper, step)`` and source indices over
    ``range(0, upper - lower, step)``.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = 1
        else:
            lower, upper, step = idx
            if step is None:
                step = 1  # range() does not accept None as a step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper-lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
324
325
def selectconcat(*args, repeat=1):
    """Concatenate the arguments, first argument most significant.

    With ``repeat`` other than 1 the whole argument list is repeated
    that many times before concatenating; in that case a single plain
    int argument is first turned into a one-bit SelectableInt.

    The result is a copy of the first item, widened and shifted left
    to take in every following item (each of which must be a
    SelectableInt).
    """
    parts = args
    if repeat != 1:
        if len(parts) == 1 and isinstance(parts[0], int):
            parts = [SelectableInt(parts[0], 1)]
        # multiplies the incoming arguments
        parts = list(parts) * repeat
    res = copy(parts[0])
    for part in parts[1:]:
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        res.bits += part.bits
        res.value = (res.value << part.bits) | part.value
    print("concat", repeat, res)
    return res
341
342
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt and selectconcat."""

    def test_arith(self):
        """Add, subtract, multiply and negate wrap at 8 bits."""
        lhs = SelectableInt(5, 8)
        rhs = SelectableInt(9, 8)
        total = lhs + rhs
        diff = lhs - rhs
        prod = lhs * rhs
        negated = -lhs
        self.assertEqual(total.value, lhs.value + rhs.value)
        self.assertEqual(diff.value, (lhs.value - rhs.value) & 0xFF)
        self.assertEqual(prod.value, (lhs.value * rhs.value) & 0xFF)
        self.assertEqual(negated.value, (-lhs.value) & 0xFF)
        # the width of the left-hand operand is carried over
        for result in (total, diff, prod):
            self.assertEqual(result.bits, lhs.bits)

    def test_logic(self):
        """and / or / xor / invert behave like the int operators."""
        lhs = SelectableInt(0x0F, 8)
        rhs = SelectableInt(0xA5, 8)
        anded = lhs & rhs
        ored = lhs | rhs
        xored = lhs ^ rhs
        inverted = ~lhs
        self.assertEqual(anded.value, lhs.value & rhs.value)
        self.assertEqual(ored.value, lhs.value | rhs.value)
        self.assertEqual(xored.value, lhs.value ^ rhs.value)
        self.assertEqual(inverted.value, 0xF0)

    def test_get(self):
        """Bit and slice reads count from the most significant bit."""
        val = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(val[7], 0)
        self.assertEqual(val[0:4], 10)
        self.assertEqual(val[4:8], 2)

    def test_set(self):
        """Bit and slice writes count from the most significant bit."""
        val = SelectableInt(0x5, 8)
        val[7] = SelectableInt(0, 1)
        self.assertEqual(val, 4)
        val[4:8] = 9
        self.assertEqual(val, 9)
        val[0:4] = 3
        self.assertEqual(val, 0x39)
        val[0:4] = val[4:8]
        self.assertEqual(val, 0x99)

    def test_concat(self):
        """Repeating a single bit 8 times gives an 8-bit value."""
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            single = SelectableInt(bit, 1)
            joined = selectconcat(single, repeat=8)
            self.assertEqual(joined, expected)
            self.assertEqual(joined.bits, 8)

    def test_repr(self):
        """repr() of every 16-bit value evaluates back to an equal object."""
        for number in range(1 << 16):
            original = SelectableInt(number, 16)
            rebuilt = eval(repr(original))
            self.assertEqual(original, rebuilt)
404
# run the unit tests above when this file is executed as a script
if __name__ == "__main__":
    unittest.main()