2 from enum
import Enum
, unique
3 from typing
import AbstractSet
, Iterable
, Iterator
, NoReturn
, Tuple
, Union
, overload
5 from cached_property
import cached_property
6 from nmutil
.plain_data
import plain_data
8 from bigint_presentation_code
.util
import OFSet
, OSet
, Self
, assert_never
, final
9 from weakref
import WeakValueDictionary
# Fragment of Fn.__init__ (the class and `def` header lines are not visible here).
# `ops`: this function's Op objects, per the type comment.
15 self
.ops
= [] # type: list[Op]
# Name -> Op registry consulted by _add_op_with_unused_name.  Being a
# WeakValueDictionary, an entry disappears once its Op has no other strong
# references.
16 op_names
= WeakValueDictionary()
17 self
.__op
_names
= op_names
# type: WeakValueDictionary[str, Op]
# Next numeric suffix to try when a requested Op name is already taken.  It
# starts at 2 and is shared by every name registered in this Fn.
18 self
.__next
_name
_suffix
= 2
# Register `op` under `name` if that name is free, otherwise under a
# suffixed variant, and return the name actually chosen.
# Raises ValueError("can't add Op to wrong Fn") when its guard (on a line not
# visible here) fails, and ValueError("Op already named") if `op` already has
# a `name` attribute.  Op.__init__ assigns `name` from this method's result,
# so an Op can only be registered once.
20 def _add_op_with_unused_name(self
, op
, name
=""):
21 # type: (Op, str) -> str
23 raise ValueError("can't add Op to wrong Fn")
24 if hasattr(op
, "name"):
25 raise ValueError("Op already named")
# Requested name unused: claim it directly (the return line is not visible).
28 if name
not in self
.__op
_names
:
29 self
.__op
_names
[name
] = op
# Otherwise build a candidate from orig_name plus the per-Fn suffix counter,
# then advance the counter.  The surrounding retry loop is not visible.
31 name
= orig_name
+ str(self
.__next
_name
_suffix
)
32 self
.__next
_name
_suffix
+= 1
# Fragment of the RegKind enum: the class header, the other members and the
# method decorators are not visible here.
# NOTE(review): `enum.auto()` needs the `enum` module itself, but the visible
# imports only show `from enum import Enum, unique`.  Confirm that
# `import enum` exists on a line not shown.
43 VL_MAXVL
= enum
.auto()
# only_scalar: whether this kind forbids vector operands (OperandType raises
# for `only_scalar and vec`).  It branches on GPR vs. CA/VL_MAXVL; the
# returned values are on lines not visible here.
46 def only_scalar(self
):
47 if self
is RegKind
.GPR
:
49 elif self
is RegKind
.CA
or self
is RegKind
.VL_MAXVL
:
# A second per-kind lookup with the same GPR vs. CA/VL_MAXVL branching.  It is
# presumably `reg_count`, which RegShape/Reg/RegClass use as their register
# bound; its header and return values are not visible.
56 if self
is RegKind
.GPR
:
58 elif self
is RegKind
.CA
or self
is RegKind
.VL_MAXVL
:
# repr as "RegKind.<member name>".
64 return "RegKind." + self
._name
_
# Fragment of OperandType (class header not visible): a frozen, hashable
# (kind, vec) pair used as the type of an Op input or output operand.
67 @plain_data(frozen
=True, unsafe_hash
=True)
70 __slots__
= "kind", "vec"
72 def __init__(self
, kind
, vec
):
73 # type: (RegKind, bool) -> None
# NOTE(review): `only_scalar` is read without being called.  That is correct
# only if it is a property (its decorator line is not visible).  As a plain
# method it would always be truthy, and every vec=True OperandType (e.g.
# OT_VGPR) would be rejected.
75 if kind
.only_scalar
and vec
:
76 raise ValueError(f
"kind={kind} must have vec=False")
# Length of one operand of this type for the given maxvl.  SSAVal uses it as
# the number of indexable elements of an output.  Most of the body is not
# visible.
79 def get_length(self
, maxvl
):
81 # here's where subvl and elwid would be accounted for
# Fragment of RegShape (class header not visible): a frozen, hashable
# (kind, length) pair; __init__ requires 1 <= length <= kind.reg_count.
87 @plain_data(frozen
=True, unsafe_hash
=True)
90 __slots__
= "kind", "length"
92 def __init__(self
, kind
, length
=1):
93 # type: (RegKind, int) -> None
95 if length
< 1 or length
> kind
.reg_count
:
96 raise ValueError("invalid length")
# try_concat: combine this shape with `others` by summing their lengths.
# Reg and RegClass arguments get a dedicated branch (its body is not visible).
# Per the type comment the result may be None.  The kind-mismatch and
# length-overflow branches below presumably return None, but those return
# lines are not visible.
99 def try_concat(self
, *others
):
100 # type: (*RegShape | Reg | RegClass | None) -> RegShape | None
104 if isinstance(other
, (Reg
, RegClass
)):
108 if other
.kind
!= self
.kind
:
110 length
+= other
.length
111 if length
> kind
.reg_count
:
113 return RegShape(kind
=kind
, length
=length
)
# Fragment of Reg (class header not visible): a frozen, hashable
# (shape, start) pair covering registers start .. start + shape.length - 1.
# __init__ rejects a span that does not fit inside 0 .. shape.kind.reg_count.
116 @plain_data(frozen
=True, unsafe_hash
=True)
119 __slots__
= "shape", "start"
121 def __init__(self
, shape
, start
):
122 # type: (RegShape, int) -> None
124 if start
< 0 or start
+ shape
.length
> shape
.kind
.reg_count
:
125 raise ValueError("start not in valid range")
# Accessors whose def/decorator lines are not visible: `kind` and `length`
# are forwarded from `shape`.
130 return self
.shape
.kind
134 return self
.shape
.length
# conflicts: True when both Regs have the same kind and their half-open
# ranges [start, stop) overlap.
136 def conflicts(self
, other
):
137 # type: (Reg) -> bool
138 return (self
.kind
== other
.kind
139 and self
.start
< other
.stop
and other
.start
< self
.stop
)
# stop: one past the last register, i.e. start + length.
143 return self
.start
+ self
.length
# try_concat: concatenate with `others`.  The combined shape comes from
# RegShape.try_concat, and each following Reg must start exactly where the
# span so far stops.  The lines tracking `stop` and the failure returns are
# not visible here.
145 def try_concat(self
, *others
):
146 # type: (*Reg | None) -> Reg | None
147 shape
= self
.shape
.try_concat(*others
)
152 assert other
is not None, "already caught by RegShape.try_concat"
153 if stop
!= other
.start
:
156 return Reg(shape
, self
.start
)
# A set of Regs that all share one RegShape, stored as a bitset of start
# positions: bit i is set <=> the Reg starting at register i is a member.
160 class RegClass(AbstractSet
[Reg
]):
# regs_or_starts may mix Reg objects, which must all share one shape, with
# start positions.  starts_bitset supplies additional starts pre-packed as a
# bitmask.
161 def __init__(self
, regs_or_starts
=(), shape
=None, starts_bitset
=0):
162 # type: (Iterable[Reg | int], RegShape | None, int) -> None
163 for reg_or_start
in regs_or_starts
:
164 if isinstance(reg_or_start
, Reg
):
166 shape
= reg_or_start
.shape
167 elif shape
!= reg_or_start
.shape
:
168 raise ValueError(f
"conflicting RegShapes: {shape} and "
169 f
"{reg_or_start.shape}")
170 start
= reg_or_start
.start
174 raise ValueError("a Reg's start is out of range")
# Each member contributes its start bit.
175 starts_bitset |
= 1 << start
# Empty-set special case (its body is not visible here).
176 if starts_bitset
== 0:
179 self
.__starts
_bitset
= starts_bitset
# A non-empty set needs a shape.
181 if starts_bitset
!= 0:
182 raise ValueError("non-empty RegClass must have non-None shape")
# NOTE(review): this bound looks off by one.  stops_bitset is
# starts_bitset << shape.length, so a Reg ending exactly at the last register
# (start + length == reg_count, which Reg.__init__ accepts) sets bit
# reg_count and is rejected here.  Allowing stops up to and including
# reg_count would be `>= 2 << shape.kind.reg_count`; confirm the intent.
184 if self
.stops_bitset
>= 1 << shape
.kind
.reg_count
:
185 raise ValueError("a Reg's start is out of range")
# shape: the common RegShape of every member, or None (per the type comment).
189 # type: () -> RegShape | None
# starts_bitset: bitmask of member start positions.
193 def starts_bitset(self
):
195 return self
.__starts
_bitset
# stops_bitset: bitmask of member stop positions (start + shape.length),
# computed by shifting starts_bitset.  The None-shape branch body is not
# visible.
198 def stops_bitset(self
):
200 if self
.__shape
is None:
202 return self
.__starts
_bitset
<< self
.__shape
.length
# Presumably `starts`: OFSet of member start positions.  Most of the body is
# not visible, and the commented-out return line below is inactive.
206 # type: () -> OFSet[int]
207 if self
.length
is None:
210 # return OFSet(for i in range(self.length))
# stops: OFSet of member stop positions (start + shape.length).
# NOTE(review): this iterates `self.__starts`, which is not defined anywhere
# visible (the public accessor is `starts`).  Confirm the attribute exists.
214 # type: () -> OFSet[int]
215 if self
.__shape
is None:
217 return OFSet(i
+ self
.__shape
.length
for i
in self
.__starts
)
# kind: the shared shape's RegKind.
221 if self
.__shape
is None:
223 return self
.__shape
.kind
227 """length of registers in this RegClass, not to be confused with the number of `Reg`s in self"""
228 if self
.__shape
is None:
230 return self
.__shape
.length
# concat: the RegClass of Regs formed by placing one member of self and one
# member of each RegClass in `others` back to back.  The combined shape comes
# from RegShape.try_concat.  For each `other`, a start s is kept only if
# s + offset is a start in `other`, where `offset` accumulates the lengths of
# the parts before it.
232 def concat(self
, *others
):
233 # type: (*RegClass) -> RegClass
237 shape
= shape
.try_concat(*others
)
240 starts
= OSet(self
.starts
)
# NOTE(review): by this point the try_concat line above has rebound `shape`
# to the concatenated shape.  Unless a line not visible here resets it, the
# first offset is the combined length rather than self's own length; confirm.
241 offset
= shape
.length
243 assert other
.__shape
is not None, \
244 "already caught by RegShape.try_concat"
245 starts
&= OSet(i
- offset
for i
in other
.starts
)
246 offset
+= other
.__shape
.length
247 return RegClass(starts
, shape
=shape
)
# Membership: the Reg has the same shape and its start is recorded.
249 def __contains__(self
, reg
):
250 # type: (Reg) -> bool
251 return reg
.shape
== self
.shape
and reg
.start
in self
.starts
# Iteration (def line not visible): one Reg per recorded start.  The body of
# the None-shape branch is not visible.
254 # type: () -> Iterator[Reg]
255 if self
.shape
is None:
257 for start
in self
.starts
:
258 yield Reg(shape
=self
.shape
, start
=start
)
# __len__ (def line not visible): the number of member Regs, not the register
# length (compare `length`).
261 return len(self
.starts
)
# __hash__ (def line not visible): collections.abc.Set._hash, which matches
# the built-in frozenset hashing algorithm for equal sets.
264 return super()._hash
()
# Fragment of a frozen, hashable class (header not visible).  Per the type
# comment it pairs an OperandType `ty` with an optional OFSet[int] `regs`.
# The __init__ body is not visible.
267 @plain_data(frozen
=True, unsafe_hash
=True)
270 __slots__
= "ty", "regs"
272 def __init__(self
, ty
, regs
=None):
273 # type: (OperandType, OFSet[int] | None) -> None
# Shared operand types referenced by the OpProperties definitions in this
# module.  Keyword arguments spell out OperandType's (kind, vec) fields.
OT_VGPR = OperandType(kind=RegKind.GPR, vec=True)  # vector GPR operand
OT_SGPR = OperandType(kind=RegKind.GPR, vec=False)  # scalar GPR operand
OT_CA = OperandType(kind=RegKind.CA, vec=False)
OT_VL = OperandType(kind=RegKind.VL_MAXVL, vec=False)
# Fragment of a frozen, hashable constraint class (header not visible;
# presumably TiedOutput, which the Constraint alias names) holding an
# (input_index, output_index) pair.
# NOTE(review): the name suggests the output must be allocated to the same
# registers as the input; confirm against the code that consumes it.
283 @plain_data(frozen
=True, unsafe_hash
=True)
285 __slots__
= "input_index", "output_index"
287 def __init__(self
, input_index
, output_index
):
288 # type: (int, int) -> None
289 self
.input_index
= input_index
290 self
.output_index
= output_index
# Union of every operand-constraint type accepted by OpProperties.
# NOTE(review): NoReturn appears to be a placeholder that keeps this a Union
# while TiedOutput is the only constraint type; confirm.
293 Constraint
= Union
[TiedOutput
, NoReturn
]
# Fragment of OpProperties (class header and the closing `):` of __init__'s
# signature are not visible).  It is a frozen, hashable static description
# of an operation: example assembly, operand types, immediate value ranges,
# constraints and flags.
296 @plain_data(frozen
=True, unsafe_hash
=True)
299 __slots__
= ("demo_asm", "inputs", "outputs", "immediates", "constraints",
300 "is_copy", "is_load_immediate", "has_side_effects")
302 def __init__(self
, demo_asm
, # type: str
303 inputs
, # type: Iterable[OperandType]
304 outputs
, # type: Iterable[OperandType]
305 immediates
, # type: Iterable[range]
306 constraints
, # type: Iterable[Constraint]
307 is_copy
=False, # type: bool
308 is_load_immediate
=False, # type: bool
309 has_side_effects
=False, # type: bool
311 # type: (...) -> None
312 self
.demo_asm
= demo_asm
# Iterable arguments are copied into tuples, so the stored values are
# immutable and hashable as the frozen/unsafe_hash decorator expects.
313 self
.inputs
= tuple(inputs
)
314 self
.outputs
= tuple(outputs
)
315 self
.immediates
= tuple(immediates
)
316 self
.constraints
= tuple(constraints
)
317 self
.is_copy
= is_copy
318 self
.is_load_immediate
= is_load_immediate
319 self
.has_side_effects
= has_side_effects
# Fragment of the enum-like class whose members wrap an OpProperties
# (presumably OpKind: Op's type comment names OpKind, and Op.properties reads
# kind.properties).  The class header is not visible.
325 def __init__(self
, properties
):
326 # type: (OpProperties) -> None
328 self
.properties
= properties
# Member definitions; each value is an OpProperties.  Several member names,
# arguments and closing parentheses are not visible.  The visible sv.*
# entries list OT_VL among their inputs.
330 SvAddE
= OpProperties(
331 demo_asm
="sv.adde *RT, *RA, *RB",
332 inputs
=(OT_VGPR
, OT_VGPR
, OT_CA
, OT_VL
),
333 outputs
=(OT_VGPR
, OT_CA
),
337 SvSubFE
= OpProperties(
338 demo_asm
="sv.subfe *RT, *RA, *RB",
339 inputs
=(OT_VGPR
, OT_VGPR
, OT_CA
, OT_VL
),
340 outputs
=(OT_VGPR
, OT_CA
),
# RB and RC are scalar GPR inputs; the outputs are a GPR vector plus a
# scalar GPR.
344 SvMAddEDU
= OpProperties(
345 demo_asm
="sv.maddedu *RT, *RA, RB, RC",
346 inputs
=(OT_VGPR
, OT_SGPR
, OT_SGPR
, OT_VL
),
347 outputs
=(OT_VGPR
, OT_SGPR
),
# Load-immediate with an immediate in range(1, 65), i.e. 1..64.
351 SetVLI
= OpProperties(
352 demo_asm
="setvl 0, 0, imm, 0, 1, 1",
355 immediates
=(range(1, 65),),
357 is_load_immediate
=True,
# Vector load-immediate; signed 16-bit immediate range(-2**15, 2**15).
360 demo_asm
="sv.addi *RT, 0, imm",
363 immediates
=(range(-2 ** 15, 2 ** 15),),
365 is_load_immediate
=True,
# Scalar load-immediate with the same signed 16-bit range.
368 demo_asm
="addi RT, 0, imm",
371 immediates
=(range(-2 ** 15, 2 ** 15),),
373 is_load_immediate
=True,
376 demo_asm
="sv.or *RT, *src, *src",
377 inputs
=(OT_VGPR
, OT_VL
),
384 demo_asm
="mv RT, src",
# Fragment of SSAVal (class header not visible): frozen and hashable, with a
# hand-written repr (repr=False).  `sliced_op_outputs` holds
# (op, output_index, range) entries, as built by __init__.
393 @plain_data(frozen
=True, unsafe_hash
=True, repr=False)
396 __slots__
= "sliced_op_outputs",
# Item forms accepted by __init__: (op, output_index, int | range | slice),
# (op, output_index) meaning the whole output, or an SSAVal.
398 _SlicedOpOutputIn
= Union
["tuple[Op, int, int | range | slice]",
399 "tuple[Op, int]", "SSAVal"]
# Normalise each input item into (op, output_index, range) triples.  This is
# presumably a @staticmethod: it takes no `self`, and the decorator line is
# not visible.
# SSAVal items contribute their existing entries.  For tuples, output_index
# must index op.properties.outputs, and the optional third element defaults
# to the whole output.  It is resolved by indexing range(cur_len), so ints
# follow sequence rules: negatives count from the end and out-of-range values
# raise IndexError.  Only step-1 selections are accepted.
402 def __process_sliced_op_outputs(inp
):
403 # type: (Iterable[_SlicedOpOutputIn]) -> Iterable[Tuple["Op", int, range]]
405 if isinstance(v
, SSAVal
):
406 yield from v
.sliced_op_outputs
410 if output_index
< 0 or output_index
>= len(op
.properties
.outputs
):
411 raise ValueError("invalid output_index")
412 cur_len
= op
.properties
.outputs
[output_index
].get_length(op
.maxvl
)
413 slice_
= slice(None) if len(v
) == 2 else v
[2]
# A range is converted to the equivalent slice so both share one code path.
414 if isinstance(slice_
, range):
415 slice_
= slice(slice_
.start
, slice_
.stop
, slice_
.step
)
416 if isinstance(slice_
, int):
417 # raise exception for out-of-range values
418 idx
= range(cur_len
)[slice_
]
419 range_
= range(idx
, idx
+ 1)
# NOTE(review): unlike int indexing, slicing range(cur_len) clamps
# out-of-range bounds instead of raising.  For example, range(2, 10) on a
# 5-element output silently becomes range(2, 5), so the next comment does
# not hold for slices or ranges.
421 # raise exception for out-of-range values
422 range_
= range(cur_len
)[slice_
]
424 raise ValueError("slice step must be 1")
427 yield op
, output_index
, range_
# Build from an SSAVal, reusing its already-normalised entries, or from items
# accepted by __process_sliced_op_outputs.  Consecutive entries for the same
# (op, output_index) whose ranges touch are merged into one range.
429 def __init__(self
, sliced_op_outputs
):
430 # type: (Iterable[_SlicedOpOutputIn] | SSAVal) -> None
# NOTE(review): the next comment mentions a `length` argument, but the visible
# signature takes only `sliced_op_outputs`.  Here `length` is only a running
# total, initialised on a line not visible, so the comment may be stale.
431 # we have length arg so plain_data.replace works
432 if isinstance(sliced_op_outputs
, SSAVal
):
433 inp
= sliced_op_outputs
.sliced_op_outputs
435 inp
= SSAVal
.__process
_sliced
_op
_outputs
(sliced_op_outputs
)
436 processed
= [] # type: list[tuple[Op, int, range]]
438 for op
, output_index
, range_
in inp
:
439 length
+= len(range_
)
# First entry: there is nothing to merge with (the line skipping the merge
# check is not visible).
440 if len(processed
) == 0:
441 processed
.append((op
, output_index
, range_
))
# Merge with the previous entry when it refers to the same output and its
# range ends where this one starts; otherwise append a new entry.
443 last_op
, last_output_index
, last_range_
= processed
[-1]
444 if last_op
== op
and last_output_index
== output_index \
445 and last_range_
.stop
== range_
.start
:
447 range_
= range(last_range_
.start
, range_
.stop
)
448 processed
[-1] = op
, output_index
, range_
450 processed
.append((op
, output_index
, range_
))
451 self
.sliced_op_outputs
= tuple(processed
)
# Concatenation: a new SSAVal with self's entries followed by other's.  Its
# __init__ merges the boundary pair if the two ranges touch.  Non-SSAVal
# operands get NotImplemented.
453 def __add__(self
, other
):
454 # type: (SSAVal) -> SSAVal
455 if not isinstance(other
, SSAVal
):
456 return NotImplemented
457 return SSAVal(self
.sliced_op_outputs
+ other
.sliced_op_outputs
)
# NOTE(review): Python only tries __radd__ when the operand types differ and
# the left operand's __add__ returned NotImplemented, or when the right
# operand is a subclass overriding the reflected method.  The SSAVal branch
# is therefore effectively unreachable for plain SSAVal + SSAVal.
459 def __radd__(self
, other
):
460 # type: (SSAVal) -> SSAVal
461 if isinstance(other
, SSAVal
):
462 return other
.__add
__(self
)
463 return NotImplemented
# One (op, output_index, i) entry per element of every range.  Per the
# existing comment the result is cached, presumably via cached_property (the
# decorator line is not visible), which is why a tuple is returned.
466 def expanded_sliced_op_outputs(self
):
467 # type: () -> tuple[tuple[Op, int, int], ...]
469 for op
, output_index
, range_
in self
.sliced_op_outputs
:
471 retval
.append((op
, output_index
, i
))
472 # must be tuple to not be modifiable since it's cached
# Indexing: an int selects one element and a slice selects a run of
# elements; both return a new SSAVal.
475 def __getitem__(self
, idx
):
476 # type: (int | slice) -> SSAVal
477 if isinstance(idx
, int):
478 return SSAVal([self
.expanded_sliced_op_outputs
[idx
]])
479 return SSAVal(self
.expanded_sliced_op_outputs
[idx
])
# __len__ (def line not visible): the number of individual elements.
482 return len(self
.expanded_sliced_op_outputs
)
# __iter__ (def line not visible): walks the elements one at a time.  The loop
# body is not visible; per the type comment it yields SSAVals.
485 # type: () -> Iterator[SSAVal]
486 for v
in self
.expanded_sliced_op_outputs
:
# Hand-written repr (def line not visible): "<opname#output_index>" per entry,
# with a "[start:stop]" suffix when the entry covers only part of that output,
# joined by " + ".  The empty-SSAVal branch body is not visible.
491 if len(self
.sliced_op_outputs
) == 0:
494 for op
, output_index
, range_
in self
.sliced_op_outputs
:
495 out_len
= op
.properties
.outputs
[output_index
].get_length(op
.maxvl
)
496 parts
.append(f
"<{op.name}#{output_index}>")
497 if range_
!= range(out_len
):
498 parts
[-1] += f
"[{range_.start}:{range_.stop}]"
499 return " + ".join(parts
)
# Fragment of Op (class header not visible).  It is frozen=True with eq=False,
# and this class defines its own __eq__/__hash__.
502 @plain_data(frozen
=True, eq
=False)
505 __slots__
= "fn", "kind", "inputs", "immediates", "outputs", "maxvl", "name"
# Constructor: copies inputs/immediates into lists, creates one SSAVal per
# declared output, then registers with `fn`, which returns the (possibly
# suffixed) unique name.
# NOTE(review): SSAVal([(self, i)]) reads self.properties (via self.kind) and
# self.maxvl, so `kind` and `maxvl` must be assigned first, presumably on the
# lines not visible here.  `name` must be assigned last, because
# Fn._add_op_with_unused_name rejects Ops that already have a `name`.
507 def __init__(self
, fn
, kind
, inputs
, immediates
, maxvl
, name
=""):
508 # type: (Fn, OpKind, Iterable[SSAVal], Iterable[int], int, str) -> None
511 self
.inputs
= list(inputs
)
512 self
.immediates
= list(immediates
)
514 outputs_len
= len(self
.properties
.outputs
)
515 self
.outputs
= tuple(SSAVal([(self
, i
)]) for i
in range(outputs_len
))
516 self
.name
= fn
._add
_op
_with
_unused
_name
(self
, name
)
# properties: the OpProperties of this Op's kind (decorator line not visible).
519 def properties(self
):
520 return self
.kind
.properties
# Equality: the Op branch's body is not visible; presumably `self is other`,
# i.e. identity comparison.  Other types get NotImplemented.
522 def __eq__(self
, other
):
523 if isinstance(other
, Op
):
525 return NotImplemented
# __hash__ (def line not visible): identity hash, consistent with an
# identity-based __eq__.
528 return object.__hash
__(self
)