add in special regs to be passed out of function (as return results)
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3 from soc.decoder.power_fields import BitRange
4 from operator import (add, sub, mul, truediv, mod, or_, and_, xor, neg, inv)
5
6
def check_extsign(a, b):
    """Re-size operand ``b`` to the width of ``a`` if ``b`` is 256 bits wide.

    A ``b`` of any other width is returned untouched.  A 256-bit ``b`` is
    replaced by a new SelectableInt holding ``b.value`` masked to ``a.bits``.
    NOTE(review): 256 looks like a placeholder for "width not yet known"
    (e.g. a constant from pseudocode); confirm against the code that
    creates 256-bit values.
    """
    if b.bits == 256:
        return SelectableInt(b.value, a.bits)
    return b
11
12
class FieldSelectableInt:
    """FieldSelectableInt: allows bit-range selection onto another target.

    ``si`` is the target SelectableInt and ``br`` maps field bit index ->
    bit index within ``si``.  A list or tuple ``br`` is converted into a
    BitRange whose i-th entry is ``br[i]``; anything else is used as-is.

    Arithmetic/logic operators gather the selected bits into a fresh
    SelectableInt (get_range), apply the operator, then scatter the result
    back into the target (merge).  Because merge() writes through to the
    shared target ``si``, these operators modify the original target as
    well as returning a FieldSelectableInt view of it.
    """

    def __init__(self, si, br):
        self.si = si  # target selectable int
        if isinstance(br, (list, tuple)):
            _br = BitRange()
            for i, v in enumerate(br):
                _br[i] = v
            br = _br
        self.br = br  # map of indices.

    def eq(self, b):
        """Make this field a copy of FieldSelectableInt ``b``.

        Both the target (shallow-copied) and the bit map are copied from b.
        """
        self.si = copy(b.si)
        self.br = copy(b.br)

    def _op(self, op, b):
        """Apply binary ``op`` to (selected bits, b) and merge the result."""
        vi = self.get_range()
        vi = op(vi, b)
        return self.merge(vi)

    def _op1(self, op):
        """Apply unary ``op`` to the selected bits and merge the result."""
        vi = self.get_range()
        vi = op(vi)
        return self.merge(vi)

    def __getitem__(self, key):
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si[key]

    def __setitem__(self, key, value):
        key = self.br[key]  # don't do POWER 1.3.4 bit-inversion
        return self.si.__setitem__(key, value)

    def __neg__(self):
        return self._op1(neg)

    # original (non-dunder) spelling, kept so existing explicit calls work
    __negate__ = __neg__

    def __invert__(self):
        return self._op1(inv)

    def __add__(self, b):
        return self._op(add, b)

    def __sub__(self, b):
        return self._op(sub, b)

    def __mul__(self, b):
        return self._op(mul, b)

    def __truediv__(self, b):
        return self._op(truediv, b)

    # Python 2 name, kept for existing callers
    __div__ = __truediv__

    def __mod__(self, b):
        return self._op(mod, b)

    def __and__(self, b):
        return self._op(and_, b)

    def __or__(self, b):
        return self._op(or_, b)

    def __xor__(self, b):
        return self._op(xor, b)

    def get_range(self):
        """Return the selected bits as a new SelectableInt of len(br) bits.

        Bit ``k`` of the result is bit ``self.br[k]`` of the target.
        """
        vi = SelectableInt(0, len(self.br))
        for k, v in self.br.items():
            vi[k] = self.si[v]
        return vi

    def merge(self, vi):
        """Scatter the bits of ``vi`` back into the target; return a view.

        ``copy`` is shallow, so ``fi.si`` is the very same object as
        ``self.si``: the assignments below modify the original target.
        """
        fi = copy(self)
        for i, v in fi.br.items():
            fi.si[v] = vi[i]
        return fi

    def __repr__(self):
        return "FieldSelectableInt(si=%s, br=%s)" % (self.si, self.br)
85
86
class FieldSelectableIntTestCase(unittest.TestCase):
    """Tests for arithmetic through a FieldSelectableInt bit-range view."""

    def test_arith(self):
        a = SelectableInt(0b10101, 5)
        b = SelectableInt(0b011, 3)
        br = BitRange()
        br[0] = 0
        br[1] = 2
        br[2] = 3
        fs = FieldSelectableInt(a, br)
        c = fs + b
        # the field is bits 0, 2, 3 (MSB0) of a = 0b110; 0b110 + 0b011
        # wraps to 0b001 in 3 bits
        self.assertEqual(c.get_range().value, 0b001)
        # merge() writes through to the shared target: a's bits 0, 2, 3
        # become 0, 0, 1, so a goes from 0b10101 to 0b00011
        self.assertEqual(a.value, 0b00011)
99
100
class SelectableInt:
    """Fixed-width integer value with POWER (MSB0) bit numbering.

    ``value`` is stored masked to ``bits`` bits, so it always lies in
    ``0 .. 2**bits - 1``.  Indexing follows the POWER ISA convention in
    which bit 0 is the MOST significant bit (POWER 3.0B p4, 1.3.2).

    Binary arithmetic/logic operators take another SelectableInt of the
    same width, or a plain int (truncated to ``bits``), and return a new
    SelectableInt of this width.  Comparisons return a 1-bit SelectableInt
    (see onebit) and also accept FieldSelectableInt operands.
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    def eq(self, b):
        """Copy the value and width of SelectableInt ``b`` into self."""
        self.value = b.value
        self.bits = b.bits

    # --- private helpers ---------------------------------------------------

    def _arith_operand(self, b):
        """Return operand ``b`` as a SelectableInt of width ``self.bits``.

        A plain int is truncated to ``self.bits``; a 256-bit operand is
        re-sized by check_extsign.  Any other width mismatch fails the
        assertion.
        """
        if isinstance(b, int):
            b = SelectableInt(b, self.bits)
        b = check_extsign(self, b)
        assert b.bits == self.bits, "%d bits vs %d bits" % (b.bits, self.bits)
        return b

    def _compare(self, other, test):
        """Shared body of the comparison operators.

        ``other`` may be a FieldSelectableInt (its selected bits are used),
        a SelectableInt of the same width (or 256 bits, see check_extsign),
        or a plain int, which is compared as-is without truncation.
        Returns onebit(test(self.value, other_value)), or NotImplemented for
        any other operand type so Python can try the reflected operation.

        NOTE(review): ordered comparisons use the unsigned values; the
        module's selectltu/selectgtu helpers are explicitly "unsigned",
        which suggests these operators may be meant to be signed - confirm.
        """
        if isinstance(other, FieldSelectableInt):
            other = other.get_range()
        if isinstance(other, SelectableInt):
            other = check_extsign(self, other)
            assert other.bits == self.bits
            other = other.value
        if not isinstance(other, int):
            return NotImplemented
        return onebit(test(self.value, other))

    def _slice_range(self, key):
        """Validate MSB0 slice ``key``; return (lsb_shift, nbits).

        An omitted start/stop defaults to 0 / ``self.bits``; only a step of
        1 is supported.
        """
        assert key.step is None or key.step == 1
        first = 0 if key.start is None else key.start
        last = self.bits if key.stop is None else key.stop
        assert first < last
        assert first >= 0
        assert last <= self.bits
        # MSB0 bits [first, last) are LSB0 bits [bits-last, bits-first)
        return self.bits - last, last - first

    # --- arithmetic / logic ------------------------------------------------

    def __add__(self, b):
        b = self._arith_operand(b)
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        b = self._arith_operand(b)
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        b = self._arith_operand(b)
        return SelectableInt(self.value * b.value, self.bits)

    def __truediv__(self, b):
        # integer division: ``self.value / b.value`` would be a float, which
        # the constructor cannot mask to ``bits``
        b = self._arith_operand(b)
        return SelectableInt(self.value // b.value, self.bits)

    # Python 2 name kept for existing callers; ``//`` behaves the same
    __div__ = __truediv__
    __floordiv__ = __truediv__

    def __mod__(self, b):
        b = self._arith_operand(b)
        return SelectableInt(self.value % b.value, self.bits)

    def __or__(self, b):
        b = self._arith_operand(b)
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        b = self._arith_operand(b)
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        b = self._arith_operand(b)
        return SelectableInt(self.value ^ b.value, self.bits)

    # int-on-the-left forms of the commutative operators, e.g. ``1 + si``
    __radd__ = __add__
    __rmul__ = __mul__
    __ror__ = __or__
    __rand__ = __and__
    __rxor__ = __xor__

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        return SelectableInt(~self.value + 1, self.bits)

    # --- bit access --------------------------------------------------------

    def __getitem__(self, key):
        """Return bit ``key`` (1-bit SI) or MSB0 slice ``key`` (n-bit SI)."""
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # NOTE: POWER 3.0B annotation order! see p4 1.3.2
            # MSB is indexed **LOWEST** (sigh)
            key = self.bits - (key + 1)
            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        if isinstance(key, slice):
            start, bits = self._slice_range(key)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)
        raise TypeError("SelectableInt indices must be int or slice, not %s"
                        % type(key).__name__)

    def __setitem__(self, key, value):
        """Overwrite bit ``key`` or MSB0 slice ``key`` with ``value``.

        ``value`` may be an int (truncated to the field width) or a
        SelectableInt whose width must equal the field width.
        """
        if isinstance(key, SelectableInt):
            key = key.value
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)  # MSB0 -> LSB0 shift
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value
            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            start, bits = self._slice_range(key)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)
        else:
            raise TypeError("SelectableInt indices must be int or slice, "
                            "not %s" % type(key).__name__)

    # --- comparisons (each returns a 1-bit SelectableInt) -------------------

    def __ge__(self, other):
        return self._compare(other, lambda x, y: x >= y)

    def __le__(self, other):
        return self._compare(other, lambda x, y: x <= y)

    def __gt__(self, other):
        return self._compare(other, lambda x, y: x > y)

    def __lt__(self, other):
        return self._compare(other, lambda x, y: x < y)

    def __eq__(self, other):
        return self._compare(other, lambda x, y: x == y)

    # --- misc --------------------------------------------------------------

    def narrow(self, bits):
        """Return a copy truncated to the low ``bits`` bits (bits <= width)."""
        assert bits <= self.bits
        return SelectableInt(self.value, bits)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value=0x{:x}, bits={})".format(self.value,
                                                             self.bits)
283
def onebit(bit):
    """Return a 1-bit SelectableInt holding 1 when ``bit`` is truthy, else 0."""
    return SelectableInt(int(bool(bit)), 1)
286
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    ``lhs`` is a SelectableInt; ``rhs`` is a SelectableInt or a plain int.
    Returns a 1-bit SelectableInt that is 1 when lhs < rhs.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < rhs_value)
293
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    ``lhs`` is a SelectableInt; ``rhs`` is a SelectableInt or a plain int.
    Returns a 1-bit SelectableInt that is 1 when lhs > rhs.
    """
    rhs_value = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > rhs_value)
300
301
# XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy elements of ``rhs`` into ``lhs`` at the positions selected by idx.

    ``idx`` is either a single index (``lhs[idx] = rhs[0]``) or a tuple
    ``(lower, upper)`` / ``(lower, upper, step)`` selecting
    ``range(lower, upper, step)`` in ``lhs``; a missing or None step means 1.
    The source positions in ``rhs`` are ``range(0, upper - lower, step)``.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            step = 1  # range() rejects a None step
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
317
318
def selectconcat(*args, repeat=1):
    """Concatenate SelectableInts, first argument in the most significant bits.

    ``repeat`` repeats the whole argument list that many times before
    concatenating.  When ``repeat != 1`` and the only argument is a plain
    int, it is first turned into a 1-bit SelectableInt, so
    ``selectconcat(1, repeat=8)`` yields 0xff in 8 bits.  Every argument
    after the first must be a SelectableInt.  The result is a shallow copy
    of the first argument, extended; the arguments themselves are not
    modified.
    """
    if repeat != 1 and len(args) == 1 and isinstance(args[0], int):
        args = [SelectableInt(args[0], 1)]
    if repeat != 1:  # multiplies the incoming arguments
        args = list(args) * repeat
    res = copy(args[0])
    for i in args[1:]:
        assert isinstance(i, SelectableInt), "can only concat SIs, sorry"
        res.bits += i.bits
        res.value = (res.value << i.bits) | i.value
    return res
334
335
class SelectableIntTestCase(unittest.TestCase):
    """Checks SelectableInt arithmetic, logic, MSB0 bit access and concat."""

    def test_arith(self):
        x = SelectableInt(5, 8)
        y = SelectableInt(9, 8)
        total = x + y
        diff = x - y
        prod = x * y
        negated = -x
        self.assertEqual(total.value, x.value + y.value)
        self.assertEqual(diff.value, (x.value - y.value) & 0xFF)
        self.assertEqual(prod.value, (x.value * y.value) & 0xFF)
        self.assertEqual(negated.value, (-x.value) & 0xFF)
        # results keep the operand width
        for result in (total, diff, prod):
            self.assertEqual(result.bits, x.bits)

    def test_logic(self):
        x = SelectableInt(0x0F, 8)
        y = SelectableInt(0xA5, 8)
        anded = x & y
        ored = x | y
        xored = x ^ y
        inverted = ~x
        self.assertEqual(anded.value, x.value & y.value)
        self.assertEqual(ored.value, x.value | y.value)
        self.assertEqual(xored.value, x.value ^ y.value)
        self.assertEqual(inverted.value, 0xF0)

    def test_get(self):
        word = SelectableInt(0xa2, 8)
        # MSB0 (big-endian) numbering: bit 0 is the top bit
        self.assertEqual(word[7], 0)
        self.assertEqual(word[0:4], 10)
        self.assertEqual(word[4:8], 2)

    def test_set(self):
        word = SelectableInt(0x5, 8)
        word[7] = SelectableInt(0, 1)   # clear LSB: 0x05 -> 0x04
        self.assertEqual(word, 4)
        word[4:8] = 9                   # low nibble: -> 0x09
        self.assertEqual(word, 9)
        word[0:4] = 3                   # high nibble: -> 0x39
        self.assertEqual(word, 0x39)
        word[0:4] = word[4:8]           # copy low nibble up: -> 0x99
        self.assertEqual(word, 0x99)

    def test_concat(self):
        for bit, expected in ((0x1, 0xff), (0x0, 0x00)):
            replicated = selectconcat(SelectableInt(bit, 1), repeat=8)
            self.assertEqual(replicated, expected)
            self.assertEqual(replicated.bits, 8)

    def test_repr(self):
        for n in range(1 << 16):
            original = SelectableInt(n, 16)
            self.assertEqual(original, eval(repr(original)))
397
# When executed as a script, run the unittest TestCases defined in this module.
if __name__ == "__main__":
    unittest.main()