from operator import ne, eq
from typing import Callable, Union
from hwt.doc_markers import internal
from hwt.hdl.const import HConst
from hwt.hdl.operator import HOperatorNode
from hwt.hdl.operatorDefs import HwtOps, HOperatorDef, CMP_OP_SWAP
from hwt.hdl.types.bits import HBits
from hwt.hdl.types.defs import BIT
from hwt.hdl.types.slice import HSlice
from hwt.hdl.types.typeCast import toHVal
from hwt.mainBases import RtlSignalBase
from hwt.math import isPow2, log2ceil
from hwt.synthesizer.rtlLevel.exceptions import SignalDriverErr
from pyMathBitPrecise.bit_utils import mask
from pyMathBitPrecise.bits3t import bitsCmp__val, bitsBitOp__val, \
bitsArithOp__val, Bits3val, bitsCmp__val_NE, bitsCmp__val_EQ
# Any HBits-typed value: a signal or a constant (referenced by name as forward references).
AnyHBitsValue = Union["HBitsRtlSignal", "HBitsConst"]
# Value accepted as the other operand of an HBits operator; values which are not HBits values yet
# (e.g. ``int``) are converted by ``toHVal`` in the operator functions of this module.
HBitsAnyCompatibleValue = Union[AnyHBitsValue, int, None]
# Value accepted as an index/slice of an HBits value (NOTE(review): presumably used by ``__getitem__``
# implementations; it is not used in the visible part of this module).
HBitsAnyIndexCompatibleValue = Union[int, slice, RtlSignalBase[HSlice], RtlSignalBase[HBits], None]
[docs]
def HBits_common_operand_type_checks_for_self(self: AnyHBitsValue):
    """
    Reject an operand whose HBits type has the ``negated`` flag set.

    :param self: operand whose ``_dtype`` is checked
    :raise TypeError: if ``self._dtype.negated`` is set
    """
    if not self._dtype.negated:
        return
    raise TypeError("HBits.negated=True supports only _isOn(), ==, !=", self)
[docs]
def HBits_common_operand_type_checks_for_other(other: AnyHBitsValue, t: HBits):
    """
    Check that the second operand has a type compatible with type ``t`` of the first operand.

    :param other: second operand whose ``_dtype`` is checked
    :param t: type of the first operand; ``other._dtype`` must be an instance of ``type(t)``
    :raise TypeError: if the type class does not match, or if ``other._dtype.negated`` is set
    """
    otherT = other._dtype
    if not isinstance(otherT, type(t)):
        raise TypeError(otherT)
    if otherT.negated:
        raise TypeError("HBits.negated=True supports only _isOn(), ==, !=", other)
[docs]
def HBits_common_operand_type_checks(self: AnyHBitsValue, other: AnyHBitsValue):
    """
    Validate the operand types of a binary HBits operator.

    Checks, in this order:
    1. ``other._dtype`` is an instance of the class of ``self._dtype``
    2. neither operand type is ``negated`` (LHS checked first)
    3. ``other._dtype`` is an ``HBits`` instance

    :raise TypeError: on the first failed check
    """
    selfT, otherT = self._dtype, other._dtype
    if not isinstance(otherT, type(selfT)):
        raise TypeError(otherT)
    negatedChecks = (
        (self, selfT, "HBits.negated=True supports only _isOn(), ==, !=, LHS: "),
        (other, otherT, "HBits.negated=True supports only _isOn(), ==, !=, RHS:"),
    )
    for operand, operandT, msg in negatedChecks:
        if operandT.negated:
            raise TypeError(msg, operand)
    if not isinstance(otherT, HBits):
        raise TypeError(otherT, selfT, self, other)
[docs]
def bitsIsOn(op: AnyHBitsValue) -> AnyHBitsValue:
    """
    Convert an HBits value to a flag which is set when the value is "on".

    * 1-bit value: the value itself, or its inversion (after reinterpretation as ``BIT``)
      if the type is ``negated``
    * wider value: ``op != 0``, or ``op == 0`` (compared in the non-negated type)
      if the type is ``negated``

    :param op: signal or constant of an HBits type
    :return: value whose type is not ``negated``
    """
    t = op._dtype
    if t.bit_length() == 1:
        if t.negated:
            res = ~(op._reinterpret_cast(BIT))
        else:
            res = op
    elif t.negated:
        # Compare in the non-negated type. The zero constant must be of that same type,
        # a constant of the negated type would be rejected by the comparison operand checks.
        plainT = t._createMutated(negated=False)
        res = op._reinterpret_cast(plainT)._eq(plainT.from_py(0))
    else:
        res = op != t.from_py(0)
    assert not res._dtype.negated, (res, res._dtype)
    return res
[docs]
@internal
def bitsCmp_detect_useless_cmp(op0: "HBitsRtlSignal", op1: "HBitsConst", op: HOperatorDef) -> Union[None, "HBitsRtlSignal", "HBitsConst", HOperatorDef]:
    """
    Try to simplify ``op0 <op> op1``, where ``op1`` is a fully valid constant,
    using the value range of the type of ``op0``.

    :param op0: the non-constant operand (its signedness selects the value range)
    :param op1: the constant operand (its width selects the value range)
    :param op: comparison operator
    :return: None if no simplification was found,
        a constant ``BIT`` if the result is known,
        ``HwtOps.EQ`` if the operator can be replaced by an equality,
        or ``op0``/``~op0`` for a 1-bit equality/inequality
    """
    v = int(op1)
    width = op1._dtype.bit_length()
    if op0._dtype.signed:
        min_val = -1 if width == 1 else -mask(width - 1) - 1
        max_val = 0 if width == 1 else mask(width - 1)
    else:
        min_val = 0
        max_val = mask(width)

    if v == min_val:
        # value can not be lower than min_val
        if op == HwtOps.GE:
            # -> always True
            return BIT.from_py(1, 1)
        elif op == HwtOps.LT:
            # -> always False
            return BIT.from_py(0, 1)
        elif op == HwtOps.LE:
            # convert <= to == to highlight the real function
            return HwtOps.EQ
    elif v == max_val:
        # value can not be greater than max_val
        if op == HwtOps.GT:
            # always False
            return BIT.from_py(0, 1)
        elif op == HwtOps.LE:
            # always True
            return BIT.from_py(1, 1)
        elif op == HwtOps.GE:
            # because value can not be greater than max
            return HwtOps.EQ
    else:
        return None

    if width == 1:
        # v is min_val or max_val of a 1-bit type. The bit pattern of v decides the result,
        # not its numeric value: for a signed type min_val=-1 is pattern 1 and max_val=0 is pattern 0,
        # for an unsigned type min_val=0 is pattern 0 and max_val=1 is pattern 1.
        constBitIsSet = bool(v & 1)
        if op == HwtOps.EQ:
            # x == 1 -> x, x == 0 -> ~x
            return op0 if constBitIsSet else ~op0
        elif op == HwtOps.NE:
            # x != 1 -> ~x, x != 0 -> x
            return ~op0 if constBitIsSet else op0
    return None
[docs]
@internal
def bitsCmp(self: AnyHBitsValue, selfIsHConst: bool, other: HBitsAnyCompatibleValue,
            op: HOperatorDef,
            selfReduceVal: HConst,
            evalFn: Callable[[AnyHBitsValue, AnyHBitsValue], AnyHBitsValue]=None) -> AnyHBitsValue:
    """
    Apply a generic comparison binary operator

    :attention: If other is bool signal convert this to bool (not ideal,
        due VHDL event operator)
    :param self: operand 0
    :param selfIsHConst: True if ``self`` is a constant
    :param other: operand 1, converted by ``toHVal`` using the type of ``self``
    :param op: operator used
    :param selfReduceVal: the value which is a result if operands are all same signal (e.g. a==a = 1, b<b=0)
    :param evalFn: override of a python operator function (by default one from "op" is used)
    :return: constant result for constant operands, otherwise a 1-bit result of the comparison
        (possibly simplified)
    :raise TypeError: if the operand types are not comparable
    """
    t = self._dtype
    other = toHVal(other, t)
    ot = other._dtype
    if op in (HwtOps.EQ, HwtOps.NE) and t.negated and t == ot:
        # == and != are allowed on negated types, compare them as non-negated
        ot = t = t._createMutated(negated=False)
        self = self._reinterpret_cast(t)
        other = other._reinterpret_cast(t)
    else:
        HBits_common_operand_type_checks(self, other)

    if evalFn is None:
        evalFn = op._evalFn

    otherIsConst = isinstance(other, HConst)
    type_compatible = False
    if t == ot:
        type_compatible = True
    elif not ot.strict_width or not ot.strict_sign:
        # lock type width/signed of other to the type of self
        type_compatible = True
        other = other._auto_cast(t)
    elif not t.strict_width or not t.strict_sign:
        # other is strict, lock type width/signed of self to the type of other
        type_compatible = True
        self = self._auto_cast(ot)
        t = self._dtype
    elif t.bit_length() == 1 and ot.bit_length() == 1 \
            and t.signed is ot.signed \
            and t.force_vector != ot.force_vector:
        # automatically cast to vector with a single item to a single bit
        if t.force_vector:
            self = self[0]
            t = self._dtype
        else:
            other = other[0]
            ot = other._dtype
        type_compatible = True

    if selfIsHConst and otherIsConst:
        if type_compatible:
            if evalFn == ne:
                return bitsCmp__val_NE(self, other)
            elif evalFn == eq:
                return bitsCmp__val_EQ(self, other)
            else:
                return bitsCmp__val(self, other, evalFn)
    else:
        if type_compatible:
            # try to reduce useless cmp
            if otherIsConst and other._is_full_valid():
                res = bitsCmp_detect_useless_cmp(self, other, op)
            elif selfIsHConst and self._is_full_valid():
                res = bitsCmp_detect_useless_cmp(other, self, CMP_OP_SWAP[op])
            else:
                res = None

            if isinstance(res, HOperatorDef):
                assert res == HwtOps.EQ, res
                op = res
            elif res is not None:
                return res

            if self is other:
                return selfReduceVal
            return HOperatorNode.withRes(op, [self, other], BIT)
        elif t.strict_width and ot.strict_width and t.bit_length() != ot.bit_length():
            pass
        elif t.signed != ot.signed:
            # handle sign casts
            if t.signed is None:
                self = self._cast_sign(ot.signed)
                return bitsCmp(self, selfIsHConst, other, op, selfReduceVal, evalFn)
            elif ot.signed is None:
                other = other._cast_sign(t.signed)
                return bitsCmp(self, selfIsHConst, other, op, selfReduceVal, evalFn)
        elif t.force_vector != ot.force_vector:
            # handle vector to bit casts
            if t.force_vector:
                self = self[0]
            else:
                other = other[0]
            return bitsCmp(self, selfIsHConst, other, op, selfReduceVal, evalFn)

    raise TypeError("Values of types (", self._dtype, other._dtype, ") are not comparable")
[docs]
@internal
def bitsBitOp(self: Union[RtlSignalBase, HConst],
              selfIsHConst: bool, other: HBitsAnyCompatibleValue,
              op: HOperatorDef,
              getVldFn: Callable[[HConst, HConst], int],
              reduceValCheckFn: Callable[[RtlSignalBase, HConst], bool],
              reduceSigCheckFn: Callable[[RtlSignalBase,  # op0Original
                                          bool,  # op0Negated
                                          bool  # op1Negated
                                          ], Union[RtlSignalBase, HConst]]) -> AnyHBitsValue:
    """
    Apply a generic bitwise binary operator

    :attention: If other is Bool signal, convert this to bool
        (not ideal, due VHDL event operator)
    :param self: operand 0
    :param selfIsHConst: True if ``self`` is a constant
    :param other: operand 1, converted by ``toHVal`` using the type of ``self``
    :param op: operator used
    :param getVldFn: function to resolve invalid (X) states
    :param reduceValCheckFn: function to reduce useless operators (partially evaluate the expression if possible)
    :param reduceSigCheckFn: function to reduce useless operators for signals and its negation flags
        (e.g. a&a = a, a&~a=0, b^b=0)
        function parameters are in format (op0Original:RtlSignalBase, op0Negated: bool, op1Negated:bool) -> Union[RtlSignalBase, HConst]:
        returns result signal if reduction is possible else None
    :return: result value of the type of ``self``
    :raise TypeError: if the operand types are not compatible
    """
    other = toHVal(other, self._dtype)
    otherIsHConst = isinstance(other, HConst)
    if selfIsHConst and otherIsHConst:
        other = other._auto_cast(self._dtype)
        HBits_common_operand_type_checks(self, other)
        return bitsBitOp__val(self, other, op._evalFn, getVldFn)

    s_t: HBits = self._dtype
    o_t: HBits = other._dtype
    HBits_common_operand_type_checks(self, other)
    if s_t != o_t:
        if s_t.signed is not o_t.signed and bool(s_t.signed) == bool(o_t.signed):
            # one operand is unsigned (signed=False) and the other has no sign (signed=None),
            # the unsigned one is converted by _vec()
            if s_t.signed == False and o_t.signed is None:
                self = self._vec()
                s_t = self._dtype
            elif s_t.signed is None and o_t.signed == False:
                other = other._vec()
                o_t = other._dtype
            else:
                raise ValueError("Invalid value for signed flag of type", s_t.signed, o_t.signed, s_t, o_t)

        if s_t == o_t:
            # due to previous cast the type may become the same
            pass
        elif s_t.bit_length() == 1 and o_t.bit_length() == 1\
                and s_t.signed is o_t.signed \
                and s_t.force_vector != o_t.force_vector:
            # automatically cast to vector with a single item to a single bit
            if s_t.force_vector:
                self = self[0]
            else:
                other = other[0]
        elif s_t == o_t._createMutated(negated=s_t.negated, strict_width=s_t.strict_width, strict_sign=s_t.strict_sign):
            # differs only in flags which do not affect this operator
            pass
        else:
            raise TypeError("Can not apply operator",
                            op, self._dtype, other._dtype)

    if otherIsHConst:
        r = reduceValCheckFn(self, other)
        if r is not None:
            return r
    elif selfIsHConst:
        r = reduceValCheckFn(other, self)
        if r is not None:
            return r
    else:
        # NOTE(review): extractNegation is not imported in the visible part of this module,
        # it is presumably defined elsewhere in this file - confirm
        _self, _self_n = extractNegation(self)
        _other, _other_n = extractNegation(other)
        if _self is _other:
            return reduceSigCheckFn(self, _self_n, _other_n)
    return HOperatorNode.withRes(op, [self, other], self._dtype)
[docs]
def HBits_auto_cast_operands_to_same_type(self: AnyHBitsValue, other: AnyHBitsValue, opForDebug):
    """
    Cast two HBits operands to a common type so that a binary operator can be applied to them.

    * width: an operand with ``strict_width=False`` is resized to the width of the other operand;
      if neither width is strict, the narrower operand is widened
    * sign: an operand with ``strict_sign=False`` takes the signedness of the other operand;
      if neither sign is strict the priority is None < False < True
    * remaining differences in strictness flags are merged
      (a flag is strict if it is strict for any of the operands)

    :param self: operand 0
    :param other: operand 1
    :param opForDebug: operator, used only in error messages
    :return: tuple (self, other) with values of the same type
    :raise TypeError: if the widths or signs differ and are strict on both sides,
        or if the types differ in something else than strictness flags
    """
    t0 = self._dtype
    t1 = other._dtype
    if t0 != t1:
        w0 = t0.bit_length()
        w1 = t1.bit_length()
        if w0 != w1:
            if not t1.strict_width:
                if not t0.strict_width and w0 < w1:
                    # both flexible: widen self, keeping its own signedness so that
                    # the sign resolution below still sees the original signs
                    self = self._auto_cast(t0._createMutated(bit_length=w1))
                else:
                    # resize other to the width of self
                    other = other._auto_cast(t1._createMutated(bit_length=w0))
            elif not t0.strict_width:
                # other has a strict width, resize self to it
                self = self._auto_cast(t0._createMutated(bit_length=w1))
            else:
                raise TypeError("incompatible width", self._dtype, opForDebug, other._dtype, self, other)

        if t0.signed != t1.signed:
            if not t1.strict_sign:
                if not t0.strict_sign:
                    if t0.signed is None or (not t0.signed and t1.signed is not None):
                        # priority None < False < True
                        self = self._cast_sign(t1.signed)
                    else:
                        other = other._cast_sign(t0.signed)
                else:
                    other = other._cast_sign(t0.signed)
            elif not t0.strict_sign:
                self = self._cast_sign(t1.signed)
            else:
                raise TypeError("incompatible sign", self._dtype, opForDebug, other._dtype, self, other)

        t0 = self._dtype
        t1 = other._dtype
        if t0 != t1:
            if t0.differs_only_in_strictness_flags(t1):
                t0 = t0._createMutated(strict_width=t0.strict_width or t1.strict_width,
                                       strict_sign=t0.strict_sign or t1.strict_sign)
                self = self._auto_cast(t0)
                other = other._auto_cast(t0)
            else:
                raise TypeError("incompatible types for operation", self._dtype, opForDebug, other._dtype, self, other)
    return self, other
[docs]
@internal
def bitsArithOp(self: AnyHBitsValue, selfIsHConst: bool, other: HBitsAnyCompatibleValue, op: HOperatorDef) -> AnyHBitsValue:
    """
    Apply a generic arithmetic binary operator (e.g. ``HwtOps.ADD``, ``HwtOps.SUB``).

    The signedness of the result is resolved from the operand types first
    (an operand with ``strict_sign=False`` adopts the sign of the other one,
    priority None < False < True if both are flexible). Then operands with ``signed=None``
    are converted to unsigned and both operands are cast to a common type.

    :param self: operand 0
    :param selfIsHConst: True if ``self`` is a constant
    :param other: operand 1, converted by ``toHVal`` using the type of ``self``
    :param op: operator used
    :return: result cast to the resolved signedness
    :raise TypeError: if the operand types are incompatible
    """
    other = toHVal(other, self._dtype)
    HBits_common_operand_type_checks(self, other)
    otherIsHConst = isinstance(other, HConst)
    t0 = self._dtype
    t1 = other._dtype
    signed = t0.signed
    # Promote sign from one of argument if necessary
    if signed != t1.signed:
        if not t1.strict_sign:
            if not t0.strict_sign:
                if t0.signed is None or (not t0.signed and t1.signed is not None):
                    # priority None < False < True
                    signed = t1.signed
                else:
                    signed = t0.signed
            else:
                signed = t0.signed
        elif not t0.strict_sign:
            signed = t1.signed
        elif bool(t0.signed) != bool(t1.signed):
            raise TypeError("incompatible sign", self._dtype, op, other._dtype, self, other)
        elif t1.signed is False:
            signed = False

    # Cast not-signed to unsigned
    if t0.signed is None:
        self = self._unsigned()
    if t1.signed is None:
        other = other._unsigned()

    self, other = HBits_auto_cast_operands_to_same_type(self, other, op)
    if selfIsHConst and otherIsHConst:
        return bitsArithOp__val(self, other, op._evalFn)._cast_sign(signed)

    if op in (HwtOps.ADD, HwtOps.SUB):
        if otherIsHConst and other._is_full_valid() and int(other) == 0:
            # x +- 0 -> x
            return self._cast_sign(signed)
        elif op == HwtOps.ADD and selfIsHConst and self._is_full_valid() and int(self) == 0:
            # 0 + x -> x
            # use the common type of the operands, the type of self before the cast above
            # may have a different width/sign
            return other._auto_cast(self._dtype)._cast_sign(signed)
    o = HOperatorNode.withRes(op, [self, other], self._dtype)
    return o._explicit_cast(self._dtype)._cast_sign(signed)
[docs]
@internal
def bitsFloordiv(self: AnyHBitsValue, selfIsHConst: bool, other: HBitsAnyCompatibleValue) -> AnyHBitsValue:
    """
    Floor division (``self // other``).

    :param self: operand 0 (dividend)
    :param selfIsHConst: True if ``self`` is a constant
    :param other: operand 1 (divisor), converted by ``toHVal`` using the type of ``self``
    :return: constant result for constant operands, otherwise a ``HwtOps.SDIV``/``HwtOps.UDIV`` node
    :raise TypeError: if the operand types are incompatible
    """
    other = toHVal(other, suggestedType=self._dtype)
    HBits_common_operand_type_checks(self, other)
    opForDebug = HwtOps.SDIV if self._dtype.signed else HwtOps.UDIV
    self, other = HBits_auto_cast_operands_to_same_type(self, other, opForDebug)
    if selfIsHConst and isinstance(other, HConst):
        return Bits3val.__floordiv__(self, other)

    # the cast above may have changed the signedness of the operands,
    # select signed/unsigned division from the final common type
    op = HwtOps.SDIV if self._dtype.signed else HwtOps.UDIV
    return HOperatorNode.withRes(op,
                                 [self, other],
                                 self._dtype)
[docs]
@internal
def _bitsMulModGetResultType(myT: "HBits", otherT: "HBits"):
    """
    Resolve the types for a multiplication/remainder node.

    A strict width/sign of an operand type has priority; if both are strict they must match
    (checked by assertions).

    :param myT: type of operand 0 before its sign cast
    :param otherT: type of operand 1
    :return: tuple (subResT, resT), the type of the operator node and the type the result is cast to
    """
    if otherT.strict_sign:
        res_sign = otherT.signed
        if myT.strict_sign:
            assert bool(res_sign) == bool(myT.signed), (myT, otherT)
    elif myT.strict_sign:
        res_sign = myT.signed
    else:
        res_sign = myT.signed or otherT.signed

    if otherT.strict_width:
        res_w = otherT.bit_length()
        if myT.strict_width:
            assert res_w == myT.bit_length(), (myT, otherT)
        subResT = resT = otherT
    elif myT.strict_width:
        res_w = myT.bit_length()
        subResT = resT = myT
    else:
        res_w = max(myT.bit_length(), otherT.bit_length())
        subResT = HBits(res_w, signed=res_sign)
        resT = HBits(res_w, signed=myT.signed)
    return subResT, resT
[docs]
@internal
def bitsMul(self: AnyHBitsValue, selfIsHConst: bool, other: HBitsAnyCompatibleValue) -> AnyHBitsValue:
    """
    Multiply two HBits values.

    Two constants are evaluated directly. A multiplication by a fully valid constant 0 or 1
    is reduced without creating an operator node.

    :param self: operand 0
    :param selfIsHConst: True if ``self`` is a constant
    :param other: operand 1, converted by ``toHVal`` using the type of ``self``
    :return: product (constant or ``HwtOps.MUL`` node cast to the result type)
    :raise TypeError: if the operand types are incompatible
    """
    bitsCls = self._dtype.__class__
    other = toHVal(other, suggestedType=self._dtype)
    HBits_common_operand_type_checks(self, other)
    otherIsHConst = isinstance(other, HConst)
    self, other = HBits_auto_cast_operands_to_same_type(self, other, HwtOps.MUL)
    if selfIsHConst and otherIsHConst:
        return Bits3val.__mul__(self, other)

    # 0 * x -> 0, 1 * x -> x
    if selfIsHConst and self._is_full_valid():
        lhsVal = int(self)
        if lhsVal == 0:
            return self._dtype.from_py(0)
        if lhsVal == 1:
            return other._auto_cast(self._dtype)
    # x * 0 -> 0, x * 1 -> x
    if otherIsHConst and other._is_full_valid():
        rhsVal = int(other)
        if rhsVal == 0:
            return self._dtype.from_py(0)
        if rhsVal == 1:
            return self

    originalT = self._dtype
    if originalT.signed is None:
        self = self._unsigned()
    if not isinstance(other._dtype, bitsCls):
        raise TypeError(self, HwtOps.MUL, other)
    if other._dtype.signed is None:
        other = other._unsigned()

    subResT, resT = _bitsMulModGetResultType(originalT, other._dtype)
    product = HOperatorNode.withRes(HwtOps.MUL, [self, other], subResT)
    return product._auto_cast(resT)
[docs]
@internal
def bitsRem(self: AnyHBitsValue, selfIsHConst: bool, other: HBitsAnyCompatibleValue) -> AnyHBitsValue:
    """
    Remainder of a division (``self % other``).

    The operator (``HwtOps.SREM``/``HwtOps.UREM``) is selected from the operand type after both
    operands were cast to a common type.

    Reductions for fully valid constants:

    * ``0 % x -> 0``
    * ``x % 0 -> x``
    * for non-signed ``x``: ``x % 2**k`` -> lower ``k`` bits of ``x`` padded with zeros
      to the original width

    :param self: operand 0 (dividend)
    :param selfIsHConst: True if ``self`` is a constant
    :param other: operand 1 (divisor), converted by ``toHVal`` using the type of ``self``
    :return: remainder value
    :raise TypeError: if the operand types are incompatible
    """
    HBits = self._dtype.__class__
    other = toHVal(other, suggestedType=self._dtype)
    HBits_common_operand_type_checks(self, other)
    otherIsHConst = isinstance(other, HConst)
    opForDebug = HwtOps.SREM if self._dtype.signed else HwtOps.UREM
    self, other = HBits_auto_cast_operands_to_same_type(self, other, opForDebug)
    # the cast above may have changed the signedness, select the operator from the final type
    op = HwtOps.SREM if self._dtype.signed else HwtOps.UREM
    if selfIsHConst and otherIsHConst:
        return Bits3val.__mod__(self, other)

    if selfIsHConst and self._is_full_valid() and int(self) == 0:
        # 0 % x == 0
        return self
    if otherIsHConst and other._is_full_valid():
        divisor = int(other)
        if divisor == 0:
            # x % 0 = x
            return self
        elif not self._dtype.signed and isPow2(divisor):
            # x % 2**cutOffBits keeps only the lower cutOffBits bits of x
            # (restricted to non-signed x, NOTE(review): for signed x SREM presumably
            # takes the sign of x, so masking would be wrong)
            cutOffBits = divisor.bit_length() - 1
            width = self._dtype.bit_length()
            if cutOffBits == 0:
                # x % 1 == 0
                return self._dtype.from_py(0)
            elif cutOffBits >= width:
                # the divisor is greater than any value of x
                return self
            return HBits(width - cutOffBits).from_py(0)._concat(self[cutOffBits:])

    myT = self._dtype
    if self._dtype.signed is None:
        self = self._unsigned()
    if isinstance(other._dtype, HBits):
        s = other._dtype.signed
        if s is None:
            other = other._unsigned()
    else:
        raise TypeError(self, op, other)

    subResT, resT = _bitsMulModGetResultType(myT, other._dtype)
    o = HOperatorNode.withRes(op, [self, other], subResT)
    return o._auto_cast(resT)
[docs]
@internal
def bitsLshift(self: AnyHBitsValue, shiftAmount: HBitsAnyCompatibleValue) -> AnyHBitsValue:
    """
    Shift left by a constant amount, the vacated low bits are filled with zeros.

    :param self: value to shift (its type must not be ``negated``)
    :param shiftAmount: constant shift amount (int or HConst); for a not fully valid HConst
        ``self._dtype.from_py(None)`` is returned
    :return: shifted value of the same width as ``self``
    """
    HBits_common_operand_type_checks_for_self(self)
    if isinstance(shiftAmount, HConst) and not shiftAmount._is_full_valid():
        return self._dtype.from_py(None)

    amount = int(shiftAmount)
    if amount == 0:
        return self
    assert amount > 0, ("shift amount must be positive value", amount)

    width = self._dtype.bit_length()
    if amount >= width:
        # every bit was shifted out, only the zero padding remains
        return HBits(width).from_py(0)
    # keep the low (width - amount) bits and append ``amount`` zero bits below them
    zeroPadding = HBits(amount).from_py(0)
    return self[(width - amount):]._concat(zeroPadding)
[docs]
@internal
def bitsRshift(self: AnyHBitsValue, shiftAmount: HBitsAnyCompatibleValue) -> AnyHBitsValue:
    """
    Shift right by a constant amount

    :note: arithmetic shift if type is signed else logical shift with 0 padding
    :param self: value to shift (its type must not be ``negated``)
    :param shiftAmount: constant shift amount (int or HConst); for a not fully valid HConst
        ``self._dtype.from_py(None)`` is returned
    :return: shifted value of the same width as ``self``
    """
    HBits_common_operand_type_checks_for_self(self)
    if isinstance(shiftAmount, HConst) and not shiftAmount._is_full_valid():
        return self._dtype.from_py(None)
    shiftAmount = int(shiftAmount)
    if shiftAmount == 0:
        return self
    assert shiftAmount > 0, ("shift amount must be positive value", shiftAmount)
    width = self._dtype.bit_length()
    if shiftAmount < width:
        # keep the upper bits and extend them back to the original width
        return self[:shiftAmount]._ext(width, bool(self._dtype.signed))

    # shiftAmount >= width: all original bits were shifted out
    # (a shift by exactly width previously failed on an assertion)
    if self._dtype.signed:
        # arithmetic shift: the result is filled with the sign bit
        msb = self[width - 1]
        return msb._sext(width)
    return self._dtype.from_py(0)