hdl.rec: make Record inherit from UserValue.
[nmigen.git] / nmigen / test / test_hdl_xfrm.py
1 # nmigen: UnusedElaboratable=no
2
3 from ..hdl.ast import *
4 from ..hdl.cd import *
5 from ..hdl.ir import *
6 from ..hdl.xfrm import *
7 from ..hdl.mem import *
8 from .utils import *
9
10
class DomainRenamerTestCase(FHDLTestCase):
    """Checks for ``DomainRenamer``: statements, drivers, clock domains and argument validation."""

    def setUp(self):
        # The expected reprs below refer to these signals by attribute name
        # (e.g. ``(sig s1)``), so each one is assigned directly to its attribute.
        self.s1 = Signal()
        self.s2 = Signal()
        self.s3 = Signal()
        self.s4 = Signal()
        self.s5 = Signal()
        self.c1 = Signal()

    def test_rename_signals(self):
        # With a plain string argument, the test expects ClockSignal()/ResetSignal()
        # and the "sync" driver to move to "pix", while references to "other" and
        # the domain-less (None) drivers stay as they are.
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            ResetSignal().eq(self.s2),
            self.s3.eq(0),
            self.s4.eq(ClockSignal("other")),
            self.s5.eq(ResetSignal("other")),
        )
        f.add_driver(self.s1, None)
        f.add_driver(self.s2, None)
        f.add_driver(self.s3, "sync")

        f = DomainRenamer("pix")(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (rst pix) (sig s2))
            (eq (sig s3) (const 1'd0))
            (eq (sig s4) (clk other))
            (eq (sig s5) (rst other))
        )
        """)
        self.assertEqual(f.drivers, {
            None: SignalSet((self.s1, self.s2)),
            "pix": SignalSet((self.s3,)),
        })

    def test_rename_multi(self):
        # A mapping argument renames several domains in one pass.
        f = Fragment()
        f.add_statements(
            self.s1.eq(ClockSignal()),
            self.s2.eq(ResetSignal("other")),
        )

        f = DomainRenamer({"sync": "pix", "other": "pix2"})(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s1) (clk pix))
            (eq (sig s2) (rst pix2))
        )
        """)

    def test_rename_cd(self):
        # The ClockDomain object itself is expected to be renamed, and the
        # fragment's domain table re-keyed accordingly.
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f = Fragment()
        f.add_domains(cd_sync, cd_pix)

        f = DomainRenamer("ext")(f)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })

    def test_rename_cd_subfragment(self):
        # Same as test_rename_cd, but the renamed domain is also registered
        # in a subfragment.
        cd_sync = ClockDomain()
        cd_pix = ClockDomain()

        f1 = Fragment()
        f1.add_domains(cd_sync, cd_pix)
        f2 = Fragment()
        f2.add_domains(cd_sync)
        f1.add_subfragment(f2)

        f1 = DomainRenamer("ext")(f1)
        self.assertEqual(cd_sync.name, "ext")
        self.assertEqual(f1.domains, {
            "ext": cd_sync,
            "pix": cd_pix,
        })

    def test_rename_wrong_to_comb(self):
        # assertRaisesRegex is used because the ``msg=`` keyword of assertRaises
        # only sets the text reported on failure; it never inspects the
        # exception, so the expected message would otherwise go unchecked.
        with self.assertRaisesRegex(ValueError,
                r"^Domain 'sync' may not be renamed to 'comb'$"):
            DomainRenamer("comb")

    def test_rename_wrong_from_comb(self):
        # See test_rename_wrong_to_comb for why assertRaisesRegex is used.
        with self.assertRaisesRegex(ValueError,
                r"^Domain 'comb' may not be renamed$"):
            DomainRenamer({"comb": "sync"})
103
104
class DomainLowererTestCase(FHDLTestCase):
    """Checks for ``DomainLowerer``: lowering of ClockSignal/ResetSignal and its error cases."""

    def setUp(self):
        # Referred to as ``(sig s)`` in the expected reprs.
        self.s = Signal()

    def test_lower_clk(self):
        # ClockSignal("sync") is expected to become the domain's clock signal.
        sync = ClockDomain()
        f = Fragment()
        f.add_domains(sync)
        f.add_statements(
            self.s.eq(ClockSignal("sync"))
        )

        f = DomainLowerer()(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s) (sig clk))
        )
        """)

    def test_lower_rst(self):
        # ResetSignal("sync") is expected to become the domain's reset signal.
        sync = ClockDomain()
        f = Fragment()
        f.add_domains(sync)
        f.add_statements(
            self.s.eq(ResetSignal("sync"))
        )

        f = DomainLowerer()(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s) (sig rst))
        )
        """)

    def test_lower_rst_reset_less(self):
        # With allow_reset_less=True, the reset of a reset-less domain is
        # expected to lower to a constant zero instead of raising.
        sync = ClockDomain(reset_less=True)
        f = Fragment()
        f.add_domains(sync)
        f.add_statements(
            self.s.eq(ResetSignal("sync", allow_reset_less=True))
        )

        f = DomainLowerer()(f)
        self.assertRepr(f.statements, """
        (
            (eq (sig s) (const 1'd0))
        )
        """)

    def test_lower_drivers(self):
        # ClockSignal/ResetSignal used as driven signals are expected to be
        # replaced by the concrete clk/rst signals of the "pix" domain.
        sync = ClockDomain()
        pix = ClockDomain()
        f = Fragment()
        f.add_domains(sync, pix)
        f.add_driver(ClockSignal("pix"), None)
        f.add_driver(ResetSignal("pix"), "sync")

        f = DomainLowerer()(f)
        self.assertEqual(f.drivers, {
            None: SignalSet((pix.clk,)),
            "sync": SignalSet((pix.rst,))
        })

    def test_lower_wrong_domain(self):
        f = Fragment()
        f.add_statements(
            self.s.eq(ClockSignal("xxx"))
        )

        # assertRaisesRegex is used because the ``msg=`` keyword of assertRaises
        # only sets the text reported on failure and never checks the exception
        # message. Parentheses are escaped since the pattern is a regex.
        with self.assertRaisesRegex(DomainError,
                r"^Signal \(clk xxx\) refers to nonexistent domain 'xxx'$"):
            DomainLowerer()(f)

    def test_lower_wrong_reset_less_domain(self):
        sync = ClockDomain(reset_less=True)
        f = Fragment()
        f.add_domains(sync)
        f.add_statements(
            self.s.eq(ResetSignal("sync"))
        )

        # See test_lower_wrong_domain for why assertRaisesRegex is used.
        with self.assertRaisesRegex(DomainError,
                r"^Signal \(rst sync\) refers to reset of reset-less domain 'sync'$"):
            DomainLowerer()(f)
189
190
class SampleLowererTestCase(FHDLTestCase):
    """Checks how ``SampleLowerer`` rewrites ``Sample`` expressions inside a fragment."""

    def setUp(self):
        # The expected reprs mention these signals by attribute name
        # (``i``, ``o1``, ...), so keep the direct attribute assignments.
        self.i = Signal()
        self.o1 = Signal()
        self.o2 = Signal()
        self.o3 = Signal()

    def test_lower_signal(self):
        # Three samples of the same signal: depths 2 and 1 in "sync", depth 1 in "pix".
        frag = Fragment()
        frag.add_statements(
            self.o1.eq(Sample(self.i, 2, "sync")),
            self.o2.eq(Sample(self.i, 1, "sync")),
            self.o3.eq(Sample(self.i, 1, "pix")),
        )

        expected = """
        (
            (eq (sig o1) (sig $sample$s$i$sync$2))
            (eq (sig o2) (sig $sample$s$i$sync$1))
            (eq (sig o3) (sig $sample$s$i$pix$1))
            (eq (sig $sample$s$i$sync$1) (sig i))
            (eq (sig $sample$s$i$sync$2) (sig $sample$s$i$sync$1))
            (eq (sig $sample$s$i$pix$1) (sig i))
        )
        """
        lowerer = SampleLowerer()
        lowered = lowerer(frag)
        self.assertRepr(lowered.statements, expected)
        # Two sample registers are driven from "sync", one from "pix".
        self.assertEqual(len(lowered.drivers["sync"]), 2)
        self.assertEqual(len(lowered.drivers["pix"]), 1)

    def test_lower_const(self):
        # Sampling a constant two cycles back.
        frag = Fragment()
        frag.add_statements(
            self.o1.eq(Sample(1, 2, "sync")),
        )

        expected = """
        (
            (eq (sig o1) (sig $sample$c$1$sync$2))
            (eq (sig $sample$c$1$sync$1) (const 1'd1))
            (eq (sig $sample$c$1$sync$2) (sig $sample$c$1$sync$1))
        )
        """
        lowerer = SampleLowerer()
        lowered = lowerer(frag)
        self.assertRepr(lowered.statements, expected)
        self.assertEqual(len(lowered.drivers["sync"]), 2)
234 self.assertEqual(len(f.drivers["sync"]), 2)
235
236
class SwitchCleanerTestCase(FHDLTestCase):
    """Checks that ``SwitchCleaner`` drops switches that contain no statements."""

    def test_clean(self):
        # Signal variables keep their names: the expected repr refers to
        # ``(sig a)`` and ``(sig b)``.
        a = Signal()
        b = Signal()
        c = Signal()

        # Build the statement tree from the inside out; the innermost switch
        # has no cases at all, so the switch on ``b`` holds nothing useful.
        empty_switch = Switch(a|b, {})
        hollow_switch = Switch(b, {1: [
            empty_switch
        ]})
        outer_switch = Switch(a, {
            1: a.eq(0),
            0: [
                b.eq(1),
                hollow_switch
            ]
        })

        cleaned = SwitchCleaner()([outer_switch])
        self.assertRepr(cleaned, """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0
                    (eq (sig b) (const 1'd1)))
            )
        )
        """)
264
265
class LHSGroupAnalyzerTestCase(FHDLTestCase):
    """Checks how ``LHSGroupAnalyzer`` groups signals that are assigned together."""

    def test_no_group_unrelated(self):
        # Signals assigned in separate statements end up in separate groups.
        a = Signal()
        b = Signal()
        statements = [
            a.eq(0),
            b.eq(0),
        ]

        analyzer = LHSGroupAnalyzer()
        found = analyzer(statements)
        expected = [
            SignalSet((a,)),
            SignalSet((b,)),
        ]
        self.assertEqual([*found.values()], expected)

    def test_group_related(self):
        # Assigning to Cat(a, b) ties both signals into a single group.
        a = Signal()
        b = Signal()
        statements = [
            a.eq(0),
            Cat(a, b).eq(0),
        ]

        analyzer = LHSGroupAnalyzer()
        found = analyzer(statements)
        expected = [
            SignalSet((a, b)),
        ]
        self.assertEqual([*found.values()], expected)

    def test_no_loops(self):
        # Repeating the same concatenated assignment still yields one group.
        a = Signal()
        b = Signal()
        statements = [
            a.eq(0),
            Cat(a, b).eq(0),
            Cat(a, b).eq(0),
        ]

        analyzer = LHSGroupAnalyzer()
        found = analyzer(statements)
        expected = [
            SignalSet((a, b)),
        ]
        self.assertEqual([*found.values()], expected)

    def test_switch(self):
        # An assignment nested in a switch is grouped on its own; the switch
        # test signal does not pull ``b`` into the group of ``a``.
        a = Signal()
        b = Signal()
        statements = [
            a.eq(0),
            Switch(a, {
                1: b.eq(0),
            })
        ]

        analyzer = LHSGroupAnalyzer()
        found = analyzer(statements)
        expected = [
            SignalSet((a,)),
            SignalSet((b,)),
        ]
        self.assertEqual([*found.values()], expected)

    def test_lhs_empty(self):
        # An assignment to an empty concatenation produces no groups.
        statements = [
            Cat().eq(0)
        ]

        analyzer = LHSGroupAnalyzer()
        found = analyzer(statements)
        self.assertEqual([*found.values()], [])
332
333
class LHSGroupFilterTestCase(FHDLTestCase):
    """Checks that ``LHSGroupFilter`` keeps only statements assigning to the given signals."""

    def test_filter(self):
        # Signal variables keep their names: the expected repr refers to ``(sig a)``.
        a = Signal()
        b = Signal()
        c = Signal()

        inner_switch = Switch(b, {1: []})
        statements = [
            Switch(a, {
                1: a.eq(0),
                0: [
                    b.eq(1),
                    inner_switch
                ]
            })
        ]

        # Only assignments to ``a`` survive; the case for 0 is left empty.
        keep_a = LHSGroupFilter(SignalSet((a,)))
        self.assertRepr(keep_a(statements), """
        (
            (switch (sig a)
                (case 1
                    (eq (sig a) (const 1'd0)))
                (case 0 )
            )
        )
        """)

    def test_lhs_empty(self):
        # Filtering with an empty signal set drops the empty-Cat assignment.
        statements = [
            Cat().eq(0)
        ]

        keep_nothing = LHSGroupFilter(SignalSet())
        self.assertRepr(keep_nothing(statements), "()")
365
366
class ResetInserterTestCase(FHDLTestCase):
    """Checks the reset switches that ``ResetInserter`` appends to a fragment."""

    def setUp(self):
        # The expected reprs refer to these signals by attribute name.
        self.s1 = Signal()
        self.s2 = Signal(reset=1)
        self.s3 = Signal(reset=1, reset_less=True)
        self.c1 = Signal()

    def test_reset_default(self):
        # A bare control signal resets the signals driven from "sync".
        frag = Fragment()
        frag.add_statements(
            self.s1.eq(1)
        )
        frag.add_driver(self.s1, "sync")

        inserter = ResetInserter(self.c1)
        frag = inserter(frag)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """)

    def test_reset_cd(self):
        # A mapping restricts the reset to the listed domain: only ``s2``,
        # driven from "pix", gets a reset statement.
        frag = Fragment()
        frag.add_statements(
            self.s1.eq(1),
            self.s2.eq(0),
        )
        frag.add_domains(ClockDomain("sync"))
        frag.add_driver(self.s1, "sync")
        frag.add_driver(self.s2, "pix")

        inserter = ResetInserter({"pix": self.c1})
        frag = inserter(frag)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """)

    def test_reset_value(self):
        # The inserted statement assigns the signal's reset value (1 for ``s2``).
        frag = Fragment()
        frag.add_statements(
            self.s2.eq(0)
        )
        frag.add_driver(self.s2, "sync")

        inserter = ResetInserter(self.c1)
        frag = inserter(frag)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 1 (eq (sig s2) (const 1'd1)))
            )
        )
        """)

    def test_reset_less(self):
        # A reset_less signal gets no reset statement; the case stays empty.
        frag = Fragment()
        frag.add_statements(
            self.s3.eq(0)
        )
        frag.add_driver(self.s3, "sync")

        inserter = ResetInserter(self.c1)
        frag = inserter(frag)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s3) (const 1'd0))
            (switch (sig c1)
                (case 1 )
            )
        )
        """)
445
446
class EnableInserterTestCase(FHDLTestCase):
    """Checks the enable gating that ``EnableInserter`` adds to fragments and memory ports."""

    def setUp(self):
        # The expected reprs refer to these signals by attribute name.
        self.s1 = Signal()
        self.s2 = Signal()
        self.s3 = Signal()
        self.c1 = Signal()

    def test_enable_default(self):
        # A bare control signal gates the signals driven from "sync": when it
        # is 0, the signal is assigned to itself.
        frag = Fragment()
        frag.add_statements(
            self.s1.eq(1)
        )
        frag.add_driver(self.s1, "sync")

        inserter = EnableInserter(self.c1)
        frag = inserter(frag)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)

    def test_enable_cd(self):
        # A mapping restricts gating to the listed domain: only ``s2``,
        # driven from "pix", is held.
        frag = Fragment()
        frag.add_statements(
            self.s1.eq(1),
            self.s2.eq(0),
        )
        frag.add_driver(self.s1, "sync")
        frag.add_driver(self.s2, "pix")

        inserter = EnableInserter({"pix": self.c1})
        frag = inserter(frag)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (eq (sig s2) (const 1'd0))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_subfragment(self):
        # The gating is applied to the subfragment as well as the parent.
        parent = Fragment()
        parent.add_statements(
            self.s1.eq(1)
        )
        parent.add_driver(self.s1, "sync")

        child = Fragment()
        child.add_statements(
            self.s2.eq(1)
        )
        child.add_driver(self.s2, "sync")
        parent.add_subfragment(child)

        inserter = EnableInserter(self.c1)
        parent = inserter(parent)
        # Unpacking also checks that there is exactly one subfragment.
        [(child, _)] = parent.subfragments
        self.assertRepr(parent.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
        )
        """)
        self.assertRepr(child.statements, """
        (
            (eq (sig s2) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s2) (sig s2)))
            )
        )
        """)

    def test_enable_read_port(self):
        # ``mem`` keeps its name: the expected repr refers to ``mem_r_en``.
        mem = Memory(width=8, depth=4)
        read_port = mem.read_port(transparent=False)
        gated_port = EnableInserter(self.c1)(read_port)
        frag = gated_port.elaborate(platform=None)
        self.assertRepr(frag.named_ports["EN"][0], """
        (m (sig c1) (sig mem_r_en) (const 1'd0))
        """)

    def test_enable_write_port(self):
        # ``mem`` keeps its name: the expected repr refers to ``mem_w_en``.
        mem = Memory(width=8, depth=4)
        write_port = mem.write_port()
        gated_port = EnableInserter(self.c1)(write_port)
        frag = gated_port.elaborate(platform=None)
        self.assertRepr(frag.named_ports["EN"][0], """
        (m (sig c1) (cat (repl (slice (sig mem_w_en) 0:1) 8)) (const 8'd0))
        """)
537
538
class _MockElaboratable(Elaboratable):
    """Minimal elaboratable whose fragment sets ``s1`` to 1 in the "sync" domain."""

    def __init__(self):
        # Exposed as an attribute so tests can reach it through wrappers.
        self.s1 = Signal()

    def elaborate(self, platform):
        # ``platform`` is not used; the same fragment is built regardless.
        frag = Fragment()
        frag.add_statements(self.s1.eq(1))
        frag.add_driver(self.s1, "sync")
        return frag
550
551
class TransformedElaboratableTestCase(FHDLTestCase):
    """Checks the wrapper object returned when a transformer is applied to an elaboratable."""

    def setUp(self):
        # The expected repr refers to these signals by attribute name.
        self.c1 = Signal()
        self.c2 = Signal()

    def test_getattr(self):
        # Attribute access on the wrapper reaches the wrapped elaboratable.
        inner = _MockElaboratable()
        wrapped = EnableInserter(self.c1)(inner)

        self.assertIs(wrapped.s1, inner.s1)

    def test_composition(self):
        # Applying a second transformer reuses the existing wrapper instead of
        # nesting a new one, and both transforms show up in the fragment.
        inner = _MockElaboratable()
        enable_inserter = EnableInserter(self.c1)
        reset_inserter = ResetInserter(self.c2)
        te1 = enable_inserter(inner)
        te2 = reset_inserter(te1)

        self.assertIsInstance(te1, TransformedElaboratable)
        self.assertIs(te1, te2)

        frag = Fragment.get(te2, None)
        self.assertRepr(frag.statements, """
        (
            (eq (sig s1) (const 1'd1))
            (switch (sig c1)
                (case 0 (eq (sig s1) (sig s1)))
            )
            (switch (sig c2)
                (case 1 (eq (sig s1) (const 1'd0)))
            )
        )
        """)
583
584
class MockUserValue(UserValue):
    """``UserValue`` stub that lowers to whatever value it was constructed with."""

    def __init__(self, lowered):
        # Initialize the base class first, then remember the lowering target.
        super().__init__()
        self.lowered = lowered

    def lower(self):
        # Hand back the stored value unchanged.
        return self.lowered
592
593
class UserValueTestCase(FHDLTestCase):
    """Checks that transformers see through a ``UserValue`` to the value it lowers to."""

    def setUp(self):
        # The expected repr refers to ``(sig s)`` and ``(sig c)``.
        self.s = Signal()
        self.c = Signal()
        self.uv = MockUserValue(self.s)

    def test_lower(self):
        sync = ClockDomain()
        frag = Fragment()
        frag.add_domains(sync)
        frag.add_statements(
            self.uv.eq(1)
        )
        # Register every signal behind the user value as driven from "sync".
        for driven in self.uv._lhs_signals():
            frag.add_driver(driven, "sync")

        expected = """
        (
            (eq (sig s) (const 1'd1))
            (switch (sig c)
                (case 1 (eq (sig s) (const 1'd0)))
            )
            (switch (sig rst)
                (case 1 (eq (sig s) (const 1'd0)))
            )
        )
        """
        # Apply the explicit reset first, then lower the domain.
        reset_inserter = ResetInserter(self.c)
        domain_lowerer = DomainLowerer()
        frag = reset_inserter(frag)
        frag = domain_lowerer(frag)
        self.assertRepr(frag.statements, expected)
623
624
class UserValueRecursiveTestCase(UserValueTestCase):
    """Same checks as ``UserValueTestCase``, with one user value nested inside another."""

    def setUp(self):
        # The expected repr refers to ``(sig s)`` and ``(sig c)``.
        self.s = Signal()
        self.c = Signal()
        inner_value = MockUserValue(self.s)
        self.uv = MockUserValue(inner_value)

    # test_lower is inherited from UserValueTestCase because the checks are the same.