# source code of yield-based parallel reduction schedule:
# https://git.libre-soc.org/?p=libreriscv.git;a=blob;f=openpower/sv/preduce.py;hb=HEAD
from preduce import preduce_y
from copy import copy

def preduce(vl, vec, pred):
    """Predicated pairwise (tree) reduction of ``vec`` that moves data.

    At each level the stride doubles and element ``i`` is combined with
    the element half a stride further on.  Lanes whose predicate bit is
    clear are skipped; when only the right-hand lane is active its value
    is copied into the left-hand slot so that it keeps taking part in
    later levels.

    Neither ``vec`` nor ``pred`` is modified: both are copied first.  A
    trace of every level is printed to stdout, showing the working
    predicate mask and the working vector.

    :param vl: number of elements to reduce (vector length)
    :param vec: sequence of values, combined with ``+``
    :param pred: sequence of truthy/falsy lane-enable flags
    :returns: a copy of ``vec`` after the reduction; element 0 holds the
        combined value of all enabled lanes (it is left untouched when
        no lane is enabled)
    """
    vec = copy(vec)
    # Working mask: a lane that receives a copied value must count as
    # active at the next level, but the caller's predicate must not be
    # damaged, so the updates are made on a private copy.
    pred = list(pred)
    step = 1
    print(" start", step, pred, vec)
    while step < vl:
        step *= 2
        for i in range(0, vl, step):
            other = i + step // 2
            # the partner lane may fall off the end when vl is not a
            # power of two; treat it as masked out in that case
            other_pred = other < vl and pred[other]
            if pred[i] and other_pred:
                vec[i] += vec[other]
            elif other_pred:
                vec[i] = vec[other]
                # the copied value now lives in lane i: mark it active,
                # otherwise the next level would overwrite or ignore it
                pred[i] = other_pred
        print("   row", step, pred, vec)
    return vec


def preducei(vl, vec, pred):
    """Predicated pairwise (tree) reduction of ``vec`` using index remapping.

    Works level by level with a doubling stride.  Instead of copying a
    value into a masked-out left-hand slot, the slot's entry in an index
    table is redirected to the element that holds the live value, so the
    data stays where it is.  Predicate bits are always looked up through
    that table.

    ``vec`` is copied first, so the caller's list is not modified.  A
    trace of every level (including the index table) is printed to stdout.

    :param vl: number of elements to reduce (vector length)
    :param vec: sequence of values, combined with ``+=``
    :param pred: sequence of truthy/falsy lane-enable flags
    :returns: the reduced copy of ``vec``
    """
    vec = copy(vec)
    step = 1
    ix = [*range(vl)]  # slot -> element currently standing in for it
    print(" start", step, pred, vec)
    while step < vl:
        half = step  # distance to the partner slot at this level
        step = step * 2
        for slot in range(0, vl, step):
            partner = slot + half
            dst = ix[slot]
            if partner < vl:
                src = ix[partner]
                src_active = pred[src]
            else:
                # partner is past the end of the vector: nothing to merge
                src = None
                src_active = False
            if not pred[dst]:
                if src_active:
                    # keep the data in place, redirect the slot instead
                    ix[slot] = src
            elif src_active:
                vec[dst] += vec[src]
        print("   row", step, pred, vec, ix)
    return vec



if __name__ == '__main__':
    # Each entry is (vector, predicate mask).  Every case is run through
    # the copying reduction, the index-remapping reduction and the
    # yield-based schedule from the preduce module.
    cases = [
        ([1, 2, 3, 4, 9, 5, 6], [0, 1, 1, 1, 0, 0, 1]),
        ([1, 2, 3, 4, 9, 5, 6], [1, 0, 0, 1, 1, 0, 1]),
        ([1, 2, 3, 4, 9, 5, 6, 8], [0, 0, 0, 0, 1, 1, 1, 1]),
        ([1, 2, 3, 4, 9, 5, 6, 8], [0, 1, 0, 0, 0, 1, 0, 1]),
        ([1, 2, 3, 4, 9, 5, 6], [1, 0, 1, 1, 0, 0, 1]),
        ([1, 2, 3, 4, 9, 5, 6], [1, 1, 1, 1, 1, 1, 1]),
    ]
    for vec, prd in cases:
        vl = len(vec)
        print(vec)
        res = preduce(vl, vec, prd)
        print(res)
        res2 = preducei(vl, vec, prd)
        print(res2)
        print()
        # The return value is ignored and vec is compared afterwards, so
        # preduce_y is presumably expected to update vec in place -
        # confirm against the preduce module.
        preduce_y(vl, vec, prd)
        print(vec)
        print()
        # the yield-based schedule must agree with the index-based one
        assert vec == res2