add modulo to parser (and div to SelectableInt)
[soc.git] / src / soc / decoder / selectable_int.py
1 import unittest
2 from copy import copy
3
4
class SelectableInt:
    """Fixed-width unsigned integer with MSB0 ("big endian") bit selection.

    ``value`` is always stored truncated to ``bits`` bits.  Index 0 refers
    to the most significant bit and index ``bits - 1`` to the least
    significant one.  Slices follow the Python convention (``start``
    included, ``stop`` excluded), so on an 8-bit value ``x[0:4]`` is the
    high nibble and ``x[4:8]`` the low nibble.

    Binary operators require both operands to have the same width and
    return a new SelectableInt of that width.  Comparisons accept either
    a SelectableInt of the same width or a plain int and return a 1-bit
    SelectableInt (1 for true, 0 for false).
    """

    def __init__(self, value, bits):
        mask = (1 << bits) - 1
        self.value = value & mask
        self.bits = bits

    # ---- arithmetic -----------------------------------------------------

    def __add__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value + b.value, self.bits)

    def __sub__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value - b.value, self.bits)

    def __mul__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value * b.value, self.bits)

    def __floordiv__(self, b):
        """Integer division; raises ZeroDivisionError when ``b`` is zero."""
        assert b.bits == self.bits
        # "//" rather than "/": a float result could not be masked with "&"
        return SelectableInt(self.value // b.value, self.bits)

    # Python 3 dispatches "/" to __truediv__ (never to __div__), so both
    # spellings map onto integer division.  __div__ is kept for callers
    # that invoke it by name.
    __truediv__ = __floordiv__
    __div__ = __floordiv__

    def __mod__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value % b.value, self.bits)

    # ---- bitwise logic --------------------------------------------------

    def __or__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value | b.value, self.bits)

    def __and__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value & b.value, self.bits)

    def __xor__(self, b):
        assert b.bits == self.bits
        return SelectableInt(self.value ^ b.value, self.bits)

    def __invert__(self):
        return SelectableInt(~self.value, self.bits)

    def __neg__(self):
        # two's complement negation, truncated to self.bits by __init__
        return SelectableInt(~self.value + 1, self.bits)

    # ---- bit selection --------------------------------------------------

    def _slice_span(self, key):
        """Validate slice ``key``; return (LSB0 shift, width in bits)."""
        assert key.step is None or key.step == 1
        assert key.start < key.stop
        assert key.start >= 0
        assert key.stop <= self.bits
        # stop is exclusive, so the field is exactly stop - start bits wide
        # and its lowest bit sits (self.bits - stop) places above bit 0
        return self.bits - key.stop, key.stop - key.start

    def __getitem__(self, key):
        """Return bit ``key`` (1-bit result) or the bit field ``key`` (slice).

        Keys that are neither int nor slice fall through and return None.
        """
        if isinstance(key, int):
            assert key < self.bits, "key %d accessing %d" % (key, self.bits)
            assert key >= 0
            # convert the MSB0 index into a right-shift amount
            key = self.bits - (key + 1)

            value = (self.value >> key) & 1
            return SelectableInt(value, 1)
        elif isinstance(key, slice):
            start, bits = self._slice_span(key)
            mask = (1 << bits) - 1
            value = (self.value >> start) & mask
            return SelectableInt(value, bits)

    def __setitem__(self, key, value):
        """Overwrite bit ``key`` or the bit field ``key`` (slice) in place.

        ``value`` may be a plain int (excess high bits are dropped) or a
        SelectableInt whose width matches the selected field exactly.
        Keys that are neither int nor slice are ignored.
        """
        if isinstance(key, int):
            assert key < self.bits
            assert key >= 0
            key = self.bits - (key + 1)
            if isinstance(value, SelectableInt):
                assert value.bits == 1
                value = value.value

            value = value << key
            mask = 1 << key
            self.value = (self.value & ~mask) | (value & mask)
        elif isinstance(key, slice):
            start, bits = self._slice_span(key)
            if isinstance(value, SelectableInt):
                assert value.bits == bits, "%d into %d" % (value.bits, bits)
                value = value.value
            mask = ((1 << bits) - 1) << start
            value = value << start
            self.value = (self.value & ~mask) | (value & mask)

    # ---- comparisons (unsigned) -----------------------------------------

    def _cmp_operand(self, other):
        """Return ``other`` as a plain int, checking width compatibility."""
        if isinstance(other, SelectableInt):
            assert other.bits == self.bits
            other = other.value
        assert isinstance(other, int)
        return other

    def __ge__(self, other):
        return SelectableInt(int(self.value >= self._cmp_operand(other)), 1)

    def __le__(self, other):
        return SelectableInt(int(self.value <= self._cmp_operand(other)), 1)

    def __gt__(self, other):
        return SelectableInt(int(self.value > self._cmp_operand(other)), 1)

    def __lt__(self, other):
        return SelectableInt(int(self.value < self._cmp_operand(other)), 1)

    def __eq__(self, other):
        return SelectableInt(int(self.value == self._cmp_operand(other)), 1)

    def __bool__(self):
        return self.value != 0

    def __repr__(self):
        return "SelectableInt(value={:x}, bits={})".format(self.value,
                                                           self.bits)
146
def onebit(bit):
    """Return a 1-bit SelectableInt: 1 when ``bit`` is truthy, otherwise 0."""
    return SelectableInt(int(bool(bit)), 1)
149
def selectltu(lhs, rhs):
    """ less-than (unsigned)

    ``lhs`` must provide ``.value``; ``rhs`` is either a SelectableInt
    (its ``.value`` is used) or a plain number.  Bit widths are not
    compared here.  The result is wrapped with ``onebit``.
    """
    other = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value < other)
156
def selectgtu(lhs, rhs):
    """ greater-than (unsigned)

    ``lhs`` must provide ``.value``; ``rhs`` is either a SelectableInt
    (its ``.value`` is used) or a plain number.  Bit widths are not
    compared here.  The result is wrapped with ``onebit``.
    """
    other = rhs.value if isinstance(rhs, SelectableInt) else rhs
    return onebit(lhs.value > other)
163
164
165 # XXX this probably isn't needed...
def selectassign(lhs, idx, rhs):
    """Copy items of ``rhs`` into ``lhs`` at the position(s) given by ``idx``.

    ``idx`` is either a single index, in which case ``lhs[idx] = rhs[0]``,
    or a tuple ``(lower, upper)`` / ``(lower, upper, step)`` describing
    ``range(lower, upper, step)`` on the ``lhs`` side.  The ``rhs`` side is
    read from ``range(0, upper - lower, step)``, i.e. it starts at 0 and
    uses the same stride.  A missing or ``None`` step means 1.

    ``lhs`` and ``rhs`` only need to support item assignment and item
    access respectively.
    """
    if isinstance(idx, tuple):
        if len(idx) == 2:
            lower, upper = idx
            step = None
        else:
            lower, upper, step = idx
        if step is None:
            # range() raises TypeError on a None step, so default it here
            step = 1
        toidx = range(lower, upper, step)
        fromidx = range(0, upper - lower, step)  # XXX eurgh...
    else:
        toidx = [idx]
        fromidx = [0]
    for t, f in zip(toidx, fromidx):
        lhs[t] = rhs[f]
180
181
def selectconcat(*args):
    """Concatenate SelectableInts; the first argument ends up most significant.

    Works on a shallow copy of ``args[0]``, so no argument is modified.
    Each further argument widens the result by its ``bits`` and is
    appended below the bits accumulated so far.
    """
    joined = copy(args[0])
    for part in args[1:]:
        assert isinstance(part, SelectableInt), "can only concat SIs, sorry"
        # make room for the new low-order bits, then widen the result
        joined.value = (joined.value << part.bits) | part.value
        joined.bits += part.bits
    return joined
189
190
class SelectableIntTestCase(unittest.TestCase):
    """Unit tests for SelectableInt arithmetic, logic and bit selection."""

    def test_arith(self):
        """Results wrap to the operand width and keep that width."""
        a = SelectableInt(5, 8)
        b = SelectableInt(9, 8)
        c = a + b
        d = a - b
        e = a * b
        f = -a
        g = b % a
        self.assertEqual(c.value, a.value + b.value)
        self.assertEqual(d.value, (a.value - b.value) & 0xFF)
        self.assertEqual(e.value, (a.value * b.value) & 0xFF)
        self.assertEqual(f.value, (-a.value) & 0xFF)
        self.assertEqual(g.value, b.value % a.value)
        self.assertEqual(c.bits, a.bits)
        self.assertEqual(d.bits, a.bits)
        self.assertEqual(e.bits, a.bits)
        self.assertEqual(g.bits, a.bits)

    def test_logic(self):
        """Bitwise operators act on the stored value, masked to 8 bits."""
        a = SelectableInt(0x0F, 8)
        b = SelectableInt(0xA5, 8)
        c = a & b
        d = a | b
        e = a ^ b
        f = ~a
        self.assertEqual(c.value, a.value & b.value)
        self.assertEqual(d.value, a.value | b.value)
        self.assertEqual(e.value, a.value ^ b.value)
        self.assertEqual(f.value, 0xF0)

    def test_get(self):
        """Index 0 is the most significant bit."""
        a = SelectableInt(0xa2, 8)
        # These should be big endian
        self.assertEqual(a[7], 0)
        self.assertEqual(a[0:4], 10)
        self.assertEqual(a[4:8], 2)

    def test_set(self):
        """Bit and slice assignment modify the value in place."""
        a = SelectableInt(0x5, 8)
        a[7] = SelectableInt(0, 1)
        self.assertEqual(a, 4)
        a[4:8] = 9
        self.assertEqual(a, 9)
        a[0:4] = 3
        self.assertEqual(a, 0x39)
        # copy the low nibble (9) over the high nibble; the result must
        # still fit in 8 bits, so it is 0x99 and can never be 0x199
        a[0:4] = a[4:8]
        self.assertEqual(a, 0x99)
236
# Run the test cases defined in this module when executed as a script.
if __name__ == "__main__":
    unittest.main()