/* From Hacker's Delight; original by hannah suarez (hcs0). */
// Computes the m+n-halfword product of n halfwords x m halfwords, unsigned.
// Max line length was 57, to fit in hacker.book (some lines here are longer).
#include <stdio.h>
#include <stdlib.h> //To define "exit", req'd by XLC.

// Computes the unsigned product w = u * v of an m-halfword number u and an
// n-halfword number v.  This is Knuth's Algorithm M from
// [Knuth Vol. 2 Third edition (1998)] section 4.3.1.
//
// w[0], u[0], and v[0] contain the LEAST significant halfwords
// (the halfwords are in little-endian order).  Picture is:
//                   u[m-1] ... u[1] u[0]
//                 x v[n-1] ... v[1] v[0]
//                   --------------------
//        w[m+n-1] ............ w[1] w[0]
//
// Parameters:
//   w     output; must have room for m + n halfwords.  It must not
//         overlap u or v, because w[0..m-1] is cleared before u and v
//         are read.
//   u, v  input operands of m and n halfwords; each element is assumed
//         to hold a 16-bit digit (0..0xFFFF).
//   m, n  operand lengths in halfwords; zero is allowed, and then the
//         corresponding part of w is simply zeroed.
void mulmnu(unsigned short w[], const unsigned short u[],
            const unsigned short v[], int m, int n)
{
    // unsigned long is guaranteed to hold at least 32 bits.  That is
    // exactly enough for one digit step:
    // (2^16-1)*(2^16-1) + (2^16-1) + (2^16-1) == 2^32 - 1.
    unsigned long k, t;
    int i, j;

    for (i = 0; i < m; i++)
        w[i] = 0;

    for (j = 0; j < n; j++)
    {
        k = 0; // Carry into the next digit position.
        for (i = 0; i < m; i++)
        {
            // w[i + j] is read before it is overwritten in this step,
            // so a single pass needs no scratch storage.
            t = (unsigned long)u[i] * v[j] + w[i + j] + k;
            w[i + j] = (unsigned short)(t & 0xFFFF);
            k = t >> 16;
        }
        w[j + m] = (unsigned short)k;
    }
}

// Running count of failed comparisons made by check().
int errors;

// Prints `label` followed by the halfwords a[0..len-1] in hex,
// least significant first, each preceded by one space.
static void print_halfwords(const char *label, const unsigned short a[],
                            int len)
{
    printf("%s", label);
    for (int j = 0; j < len; j++)
        printf(" %04x", (unsigned)a[j]);
}

// Compares the m+n halfword product `result` against `correct`.
// On the first differing halfword, increments `errors` once and prints
// the operands u (m halfwords) and v (n halfwords) together with the
// expected and the computed product.  Nothing is printed on a match.
void check(const unsigned short result[], const unsigned short u[],
           const unsigned short v[], int m, int n,
           const unsigned short correct[])
{
    for (int i = 0; i < m + n; i++)
    {
        if (correct[i] != result[i])
        {
            errors++;
            printf("Error, m = %d, n = %d,", m, n);
            print_halfwords(" u =", u, m);
            print_halfwords(" v =", v, n);
            print_halfwords("\nShould get:", correct, m + n);
            print_halfwords("\n       Got:", result, m + n);
            printf("\n");
            return; // Report each product at most once.
        }
    }
}

// Runs every test case through mulmnu in both operand orders and checks
// each product with check().  Returns EXIT_SUCCESS when every product is
// correct and EXIT_FAILURE otherwise, so the program can serve as a
// pass/fail test.
int main(void)
{
    static struct
    {
        int m, n;
        unsigned short u[4], v[4], correct[8];
    } test[] = {
        // clang-format off
        {.m=1, .n=1, .u={7}, .v={3}, .correct={21, 0}},
        {.m=1, .n=1, .u={2}, .v={0xFFFF}, .correct={0xFFFE, 0x0001}}, // 2*FFFF = 0001_FFFE.
        {.m=1, .n=1, .u={0xFFFF}, .v={0xFFFF}, .correct={1, 0xFFFE}},
        {.m=1, .n=2, .u={7}, .v={5, 6}, .correct={35, 42, 0}},
        {.m=1, .n=2, .u={65000}, .v={63000, 64000}, .correct={0xBDC0, 0x8414, 0xF7F5}},
        {.m=1, .n=3, .u={65535}, .v={31000, 32000, 33000}, .correct={0x86E8, 0xFC17, 0xFC17, 0x80E7}},
        {.m=2, .n=3, .u={400, 300}, .v={500, 100, 200}, .correct={0x0D40, 0xE633, 0xADB2, 0xEA61, 0}},
        {.m=2, .n=3, .u={400, 65535}, .v={500, 100, 65534}, .correct={0x0D40, 0x9A4F, 0xFE70, 0x01F5, 0xFFFD}},
        {.m=4, .n=4, .u={65535, 65535, 65535, 65535}, .v={65535, 65535, 65535, 65535},
         .correct={1, 0, 0, 0, 65534, 65535, 65535, 65535}},
        {.m=2, .n=2, .u={0xFF00, 0xFF00}, .v={0xFF00, 0xFF00}, .correct={0, 0xfe01, 0xfc02, 0xfe02}},
        // clang-format on
    };
    const int ncases = sizeof(test) / sizeof(test[0]);
    // Large enough for the longest product a test case can describe.
    unsigned short result[sizeof test[0].correct / sizeof test[0].correct[0]];

    printf("mulmnu:\n");
    for (int i = 0; i < ncases; i++)
    {
        int m = test[i].m, n = test[i].n;
        unsigned short *u = test[i].u;
        unsigned short *v = test[i].v;
        unsigned short *correct = test[i].correct;
        mulmnu(result, u, v, m, n);
        check(result, u, v, m, n, correct);
        mulmnu(result, v, u, n, m); // Interchange operands.
        check(result, v, u, n, m, correct);
    }

    if (errors == 0)
    {
        printf("Passed all %d cases.\n", ncases);
        return EXIT_SUCCESS;
    }
    // Each case is checked twice (once per operand order).
    printf("Failed %d of %d checks.\n", errors, 2 * ncases);
    return EXIT_FAILURE;
}