def aoi(a, b, c, d):
    """AND-OR-INVERT: return the complement of ``(a & b) | (c & d)``."""
    # De Morgan: ~(x | y) == ~x & ~y, exact for Python's unbounded ints.
    return ~(a & b) & ~(c & d)


def oai(a, b, c, d):
    """OR-AND-INVERT: return the complement of ``(a | b) & (c | d)``."""
    # De Morgan: ~(x & y) == ~x | ~y, exact for Python's unbounded ints.
    return ~(a | b) | ~(c | d)


def and_or(a, b, c, d):
    """AND-OR (non-inverting): return ``(a & b) | (c & d)``."""
    first_term = a & b
    second_term = c & d
    return first_term | second_term


def grev_wires(a, step, width=64):
    """Return ``a`` with each bit i moved to bit ``i ^ step``.

    Only bits 0 .. width-1 of ``a`` are examined; higher bits are dropped.
    Negative ``a`` is read with Python's two's-complement semantics, so
    ``-1`` contributes all ``width`` bits.

    :param a: input word (Python int).
    :param step: XOR distance between source and destination bit index.
        If ``step >= width``, destination bits can land at or above
        ``width``.
    :param width: number of low bits of ``a`` to process (default 64,
        matching the original fixed width).
    :returns: the permuted value as a non-negative int.
    """
    retval = 0
    for i in range(width):
        if a & (1 << i):
            retval |= 1 << (i ^ step)
    return retval


def grev_mask(step, width=64):
    """Return a mask of the bit positions i (0 <= i < width) with ``~i & step``.

    A bit i is set when at least one bit of ``step`` is clear in ``i``.  For
    a power-of-two ``step`` this picks the lower index of every
    ``(i, i ^ step)`` pair. Examples: step=1 gives 0x5555..., step=2 gives
    0x3333....

    :param step: pattern tested against each bit index.
    :param width: number of mask bits to generate (default 64, matching the
        original fixed width).
    :returns: the mask as a non-negative int.
    """
    retval = 0
    for i in range(width):
        if ~i & step:
            retval |= 1 << i
    return retval


def grevlut_grev_gorc(a, sh, imm_lut, inv_in, inv_out):
    """Experimental six-stage bit-permutation model driven by a small LUT.

    Stage k (k = 0..5) uses stride ``2**k``. In each stage:

    * ``swapped`` is the current value with bit i moved to bit i ^ stride
      (via ``grev_wires``);
    * bit k of ``sh`` picks whether the LUT is read from ``imm_lut`` or
      ``imm_lut >> 1``.  In that shifted value, bits 0 and 2 build selector
      ``sel_b`` and bits 4 and 6 build selector ``sel_d``.  The first bit of
      each pair enables the ``grev_mask(stride)`` lanes and the second bit
      enables the complementary lanes;
    * even stages compute ``aoi(value, sel_b, swapped, sel_d)``, odd stages
      ``oai(value, ~sel_b, swapped, ~sel_d)``.

    Only bits 0-5 of ``sh`` and bits 0-7 of ``imm_lut`` affect the result.

    :param a: input word (Python int).
    :param sh: per-stage LUT-offset select bits.
    :param imm_lut: LUT immediate.
    :param inv_in: if true, complement ``a`` before the first stage.
    :param inv_out: if true, complement the result after the last stage.
    :returns: the result reduced modulo 2**64.
    """
    # NOTE(review): the AOI/OAI alternation with inverted selectors on odd
    # stages presumably keeps signal polarity consistent between the
    # inverting gates -- confirm against the intended hardware design.
    value = ~a if inv_in else a
    for stage in range(6):
        stride = 1 << stage
        lanes = grev_mask(stride)
        other_lanes = ~lanes
        lut = imm_lut >> ((sh >> stage) & 1)
        sel_b = (lanes if lut & 0x1 else 0) | (other_lanes if lut & 0x4 else 0)
        sel_d = (lanes if lut & 0x10 else 0) | (other_lanes if lut & 0x40 else 0)
        swapped = grev_wires(value, stride)
        if stage % 2:
            value = oai(value, ~sel_b, swapped, ~sel_d)
        else:
            value = aoi(value, sel_b, swapped, sel_d)
    if inv_out:
        value = ~value
    # Intermediate values may be negative Python ints; fold back to 64 bits.
    return value % 2**64


def case(a, sh, imm_lut, inv_in, inv_out):
    """Evaluate ``grevlut_grev_gorc`` for one input set and print one line.

    The printed line has the form
    ``  gl(a, sh, imm_lut, inv_in, inv_out) = result``, with ``a``, ``sh``
    and ``result`` in hex, ``imm_lut`` in binary, and the two flags as
    integers.

    :returns: the computed ``grevlut_grev_gorc`` value, so callers can
        reuse it as well as print it.
    """
    v = grevlut_grev_gorc(a, sh, imm_lut, inv_in, inv_out)
    # The closing ')' after the argument list was previously missing.
    print("  gl(%x, %x, %s, %d, %d) = %x" % (a, sh, bin(imm_lut),
                                             inv_in, inv_out, v))
    return v

if __name__ == '__main__':
    def _sweep(imm_values):
        """Print a header per LUT immediate, then one case per shift 0..63."""
        for imm in imm_values:
            print("imm", bin(imm))
            for shamt in range(64):
                case(0x5555_5555_5555_5555, shamt, imm, True, False)
            print()

    # quick experiment: a hand-picked set of LUT immediates
    _sweep([0b101_0010, 0b101_0110, 0b1110_0110, 0b1001_0011])

    print()
    print()

    # Explore all possible constants, 16 at a time. Edit the OR below to
    # try different combinations (| 0b0101_0000, | 0b1000_0000, etc.).
    _sweep([low | 0b0101_0000 for low in range(16)])