from typing import Union, Optional, Literal
from hwt.doc_markers import internal
from hwt.hdl.const import HConst
from hwt.hdl.operator import HOperatorNode
from hwt.hdl.operatorDefs import HwtOps, HOperatorDef
from hwt.hdl.types.bitConstFunctions import AnyHBitsValue, \
HBitsAnyIndexCompatibleValue
from hwt.hdl.types.bits import HBits
from hwt.hdl.types.defs import INT, SLICE, BIT, BIT_N
from hwt.hdl.types.slice import HSlice
from hwt.hdl.types.sliceUtils import slice_to_HSlice
from hwt.hdl.types.typeCast import toHVal
from hwt.mainBases import RtlSignalBase
from pyMathBitPrecise.bits3t import Bits3val
[docs]
@internal
def _match_msb_get(v: "HBitsRtlSignal"):
"""
:returns: x if v == x[x.width - 1] else None
"""
if v._dtype.bit_length() != 1:
return None
opVIsResultOf = _get_operator_i_am_the_result_of(v)
if opVIsResultOf == HwtOps.INDEX:
iOp = v.singleDriver()
iOpSrc, iOpI = iOp.operands
if isinstance(iOpI, HConst) and iOpI._is_full_valid() and isinstance(iOpI._dtype, HBits) and int(iOpI) == iOpSrc._dtype.bit_length() - 1:
return iOpSrc
return None
[docs]
@internal
def _fold_concat_of_msb_using_sext(v: "HBitsRtlSignal", vReplicatinCount:int, other: "HBitsRtlSignal", other_w: int):
msbSrc = _match_msb_get(v)
if msbSrc is other or other._dtype.bit_length() == 1 and v is other:
# fold concat(x.msb, x) -> sext(x)
return other._sext(other_w + vReplicatinCount)
else:
opOtherIsResultOf = _get_operator_i_am_the_result_of(other)
if opOtherIsResultOf == HwtOps.SEXT and msbSrc == other.singleDriver().operands[0]:
# fold concat(x.msb, sext(x)) -> sext(x)
return msbSrc._sext(other_w + vReplicatinCount)
elif opOtherIsResultOf == HwtOps.CONCAT:
otherD: HOperatorNode = other.singleDriver()
highBits, lowBits = otherD.operands
if highBits is v:
# fold concat(x1b, concat(x1b, y))) -> concat(sext(x1b), y)
assert highBits._dtype.bit_length() == 1, highBits
return highBits._sext(1 + vReplicatinCount)._concat(lowBits)
elif highBits is msbSrc:
# fold concat(x.msb, concat(x, y)) -> concat(sext(x), y)
return highBits._sext(highBits._dtype.bit_length() + vReplicatinCount)._concat(lowBits)
else:
opHighBitsIsResultOf = _get_operator_i_am_the_result_of(highBits)
if opHighBitsIsResultOf == HwtOps.SEXT:
hiBitsSrc = highBits.singleDriver().operands[0]
if hiBitsSrc is msbSrc:
# fold concat(x.msb, concat(sext(x), y)) -> concat(sext(x), y)
return msbSrc._sext(highBits._dtype.bit_length() + vReplicatinCount)._concat(lowBits)
return None
[docs]
@internal
def _get_operator_i_am_the_result_of(const_or_sig: Union[RtlSignalBase, HConst]) -> Optional[HOperatorDef]:
if len(const_or_sig._rtlDrivers) == 1 and isinstance(const_or_sig._rtlObjectOrigin, HOperatorNode):
return const_or_sig._rtlObjectOrigin.operator
else:
return None
[docs]
@internal
def bitsGetitem_foldSliceOnCONCAT(v: AnyHBitsValue, start:int, stop: int, key: HBitsAnyIndexCompatibleValue) -> AnyHBitsValue:
op_h, op_l = v._rtlObjectOrigin.operands
op_l_w = op_l._dtype.bit_length()
assert start > stop, (start, stop, "Should be in MSB:LSB format")
if start <= op_l_w:
# entirely in first operand of concat
if op_l_w == 1:
if isinstance(key._dtype, HSlice):
assert int(key.val.start) == 1 and int(key.val.stop) == 0 and int(key.val.step) == -1, key
return op_l
else:
assert int(key) == 0, key
return op_l
else:
return op_l[key]
elif stop >= op_l_w:
# intirely in second operand of concat
start -= op_l_w
stop -= op_l_w
if op_h._dtype.bit_length() == 1:
assert start - stop == 1
return op_h
else:
return op_h[SLICE.from_py(slice(start, stop, -1))]
else:
# partially in op_h and op_l, allpy slice on concat operands and return concatenation of it
if stop != 0 or op_l._dtype.bit_length() > 1:
op_l = op_l[:stop]
if op_h._dtype.bit_length() == 1:
assert start - op_l_w == 1, ("Out of range slice (but this error should be catched sooner)", v, key)
else:
op_h = op_h[start - op_l_w:0]
return op_h._concat(op_l)
[docs]
@internal
def bitsGetitem_foldSliceOnEXT(v: AnyHBitsValue,
start:int, stop: int,
key: HBitsAnyIndexCompatibleValue,
iAmResultOfOp: Literal[HwtOps.ZEXT, HwtOps.SEXT]) -> AnyHBitsValue:
assert iAmResultOfOp in (HwtOps.ZEXT, HwtOps.SEXT), iAmResultOfOp
# :note: start points at MSB and stop on LSB (start:stop, eg 8:0)
extSrc = v.singleDriver().operands[0]
extSrcWidth = extSrc._dtype.bit_length()
resultWidth = start - stop
if start < extSrcWidth:
# selecting only bits from extSrc
if stop == 0:
return extSrc._trunc(start)
else:
return extSrc[key]
elif stop >= extSrcWidth:
# only msb bits are selected from ext and
if iAmResultOfOp == HwtOps.ZEXT:
return v._dtype._createMutated(resultWidth).from_py(0)
else:
return extSrc.getMsb()._sext(resultWidth)
else:
# selected value overlaps between extSrc and extension bits
if stop != 0:
extSrc = extSrc[:stop]
return extSrc._ext(resultWidth, iAmResultOfOp == HwtOps.SEXT)
[docs]
@internal
def bitsGetitem_foldBitGetOnEXT(v: AnyHBitsValue,
i:int,
key: HBitsAnyIndexCompatibleValue,
iAmResultOfOp: Literal[HwtOps.ZEXT, HwtOps.SEXT]) -> AnyHBitsValue:
# fold zext(x)[i] -> x[i] if x.width < i else 0
# fold sext(x)[i] -> x[i] if x.width < i else x.msb
extSrc = v.singleDriver().operands[0]
extSrcWidth = extSrc._dtype.bit_length()
if i < extSrcWidth:
# selecting only bits from extSrc
return extSrc[key]
else:
# only msb bits are selected from ext and
if iAmResultOfOp == HwtOps.ZEXT:
return v._dtype._createMutated(1).from_py(0)
else:
return extSrc.getMsb()
[docs]
@internal
def bitsGetitem_foldBitGetOnConcat(v: AnyHBitsValue, key: HBitsAnyIndexCompatibleValue, _index:int, iAmResultOfOp: Optional[HOperatorDef]):
# index directly in the member of concatenation
update_key = False
while iAmResultOfOp == HwtOps.CONCAT:
op_h, op_l = v._rtlObjectOrigin.operands
op_l_w = op_l._dtype.bit_length()
if _index < op_l_w:
v = op_l
else:
v = op_h
_index -= op_l_w
update_key = True
iamConst = isinstance(v, HConst)
iAmResultOfOp = None if iamConst else _get_operator_i_am_the_result_of(v)
# [todo] check if swap of negated flag can cause anything wrong
if update_key:
key = key._dtype.from_py(_index)
return v, key
[docs]
@internal
def bitsGetitem(v: AnyHBitsValue, iamConst:bool, key: HBitsAnyIndexCompatibleValue) -> AnyHBitsValue:
"""
[] operator
:attention: Table below is for little endian bit order (MSB:LSB)
which is default.
This is **reversed** as it is in pure python
where it is [0, len(v)].
:attention: Slice on slice signal is automatically reduced
to single slice. This function also looks trough concatenations.
+-----------------------------+----------------------------------------------------------------------------------+
| a[up:low] | items low through up; a[16:8] selects upper byte from 16b vector a |
+-----------------------------+----------------------------------------------------------------------------------+
| a[up:] | low is automatically substituted with 0; a[8:] will select lower 8 bits |
+-----------------------------+----------------------------------------------------------------------------------+
| a[:end] | up is automatically substituted; a[:8] will select upper byte from 16b vector a |
+-----------------------------+----------------------------------------------------------------------------------+
| a[:], a[-1], a[-2:], a[:-2] | raises NotImplementedError (not implemented due to complicated support in hdl) |
+-----------+----------------------------------------------------------------------------------------------------+
:note: signed is preserved as in VHDL, and not like in Verilog where result of slice is always unsigned
"""
st = v._dtype
vWidth = st.bit_length()
if isinstance(key, slice):
key = slice_to_HSlice(key, vWidth)
isSLICE = True
else:
isSLICE = isinstance(key, HSlice.getConstCls())
is1bScalar = vWidth == 1 and not st.force_vector
if not isSLICE:
if is1bScalar and \
((isinstance(key, int) and key == 0) or\
(isinstance(key, HConst) and key._is_full_valid() and int(key) == 0)):
return v
key = toHVal(key, INT)
else:
if is1bScalar and key.val.start == 1 and key.val.stop == 0 and key.val.step == -1:
return v
if is1bScalar:
# assert not indexing on single bit
raise IndexError("indexing on single bit")
iAmResultOfOp = None if iamConst else _get_operator_i_am_the_result_of(v)
if iAmResultOfOp == HwtOps.TRUNC:
# fold trunc(x)[i] to x[i]
return v.singleDriver().operands[0][key]
elif iAmResultOfOp == HwtOps.BitsAsSigned or iAmResultOfOp == HwtOps.BitsAsUnsigned:
# fold x._signed()[i] to x[i]._signed()
return iAmResultOfOp._evalFn(v.singleDriver().operands[0][key])
HBits = v._dtype.__class__
if isSLICE:
# :note: downto notation
start = key.val.start
stop = key.val.stop
if key.val.step != -1:
raise NotImplementedError()
startIsConst = isinstance(start, HConst)
stopIsConst = isinstance(stop, HConst)
indexesAreHConst = startIsConst and stopIsConst
if indexesAreHConst and start.val == vWidth and stop.val == 0:
# selecting all bits no conversion needed
# fold x[h:l] -> x
return v
# check start boundaries
if startIsConst:
_start = int(start)
if _start < 0 or _start > vWidth:
raise IndexError("start index is out of range start:", _start, " width:", vWidth, "")
# check end boundaries
if stopIsConst:
_stop = int(stop)
if _stop < 0 or _stop >= vWidth:
raise IndexError("stop index is out of range stop:", _stop, " width:", vWidth)
# check width of selected range
if startIsConst and stopIsConst and _start - _stop <= 0:
raise IndexError("start (represents MSB bit index +1) must be > stop (represents LSB bit index)", _start, _stop)
if iAmResultOfOp == HwtOps.INDEX:
# try reduce v and parent slice to one
# fold x[a:b][start:stop] -> x[b+start:b+stop]
original, parentIndex = v._rtlObjectOrigin.operands
if isinstance(parentIndex._dtype, HSlice):
parentLower = parentIndex.val.stop
start = parentLower + start
stop = parentLower + stop
return original[start:stop]
elif startIsConst and stopIsConst:
# index directly in the member of concatenation
# :note: start points at MSB and stop on LSB (start:stop, eg 8:0)
stop = int(stop)
start = int(start)
if iAmResultOfOp == HwtOps.CONCAT:
return bitsGetitem_foldSliceOnCONCAT(v, start, stop, key)
elif iAmResultOfOp == HwtOps.ZEXT or iAmResultOfOp == HwtOps.SEXT:
return bitsGetitem_foldSliceOnEXT(v, start, stop, key, iAmResultOfOp)
elif stop == 0:
return v._trunc(start)
if iamConst:
if isinstance(key, SLICE.getConstCls()):
key = key.val
res = Bits3val.__getitem__(v, key)
if res._dtype.bit_length() == 1 and not res._dtype.force_vector:
assert res._dtype is not v._dtype
res._dtype.force_vector = True
return res
else:
key = SLICE.from_py(slice(start, stop, -1))
_resWidth = start - stop
resT = HBits(bit_length=_resWidth, force_vector=_resWidth == 1,
signed=st.signed, negated=st.negated)
elif isinstance(key, HBits.getConstCls()):
# int like value addressing a single bit
if st.negated:
resT = BIT_N
else:
resT = BIT
if not key._is_full_valid():
return resT.from_py(None)
# check index range
_index = int(key)
# if _index == 0 and not st.force_vector:
# # fold x[0] -> x._trunc(1)
# return v._trunc(1)
if _index < 0 or _index > vWidth - 1:
raise IndexError(_index)
if iAmResultOfOp == HwtOps.INDEX:
# index directly in parent signal
# fold x[a:b][i] -> x[b+i]
original, parentIndex = v._rtlObjectOrigin.operands
if isinstance(parentIndex._dtype, HSlice):
parentLower = parentIndex.val.stop
return original[parentLower + _index]
elif iAmResultOfOp == HwtOps.TRUNC:
# fold x._trunc(n)[i] to x[i]
original = v._rtlObjectOrigin.operands[0]
return original[_index]
elif iAmResultOfOp == HwtOps.ZEXT or iAmResultOfOp == HwtOps.SEXT:
return bitsGetitem_foldBitGetOnEXT(v, _index, key, iAmResultOfOp)
else:
# index directly in the member of concatenation
# fold concat(a, x)[i] -> x[i]
_v, _key = bitsGetitem_foldBitGetOnConcat(v, key, _index, iAmResultOfOp)
changed = v is not _v or _key is not key
v = _v
key = _key
iamConst = isinstance(v, HConst)
st = v._dtype
if isinstance(key, HBits.getConstCls()) and int(key) == 0 and (
v._dtype.bit_length() == 1 and not v._dtype.force_vector
):
return v
elif changed:
return v[key]
if iamConst:
# at the end because multiple non-constant indexes may be applied on constant and we want to merge them
return Bits3val.__getitem__(v, key)
elif key._is_full_valid() and int(key) == 0 and (v._dtype == BIT or v._dtype == BIT_N):
return v
elif isinstance(key, RtlSignalBase):
t = key._dtype
if isinstance(t, HSlice):
bit_length = key.staticEval()._size()
resT = HBits(bit_length,
force_vector=bit_length == 1,
signed=st.signed,
negated=st.negated)
elif isinstance(t, HBits):
resT = BIT
else:
raise TypeError(
"Index operation not implemented"
" for index of type ", t)
else:
raise TypeError(
"Index operation not implemented for index ", key)
if st.negated and resT is BIT:
resT = BIT_N
return HOperatorNode.withRes(HwtOps.INDEX, [v, key], resT)