#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <inttypes.h>
#include <stdlib.h>
#include <string.h>

#include "benc.h"

#define benc_safeset(out, val) if ((out) != NULL) *(out) = (val)

static const char *benc_validate_aux(const char *p, const char *end);

int
benc_validate(const char *p, size_t len)
{
    const char *end = p + len - 1;

    if (len <= 0)
        return EINVAL;

    return benc_validate_aux(p, end) == end ? 0 : EINVAL;
}

static const char *
benc_validate_aux(const char *p, const char *end)
{
    size_t d = 0;
    switch (*p) {
    case 'd':
        d = 1;
    case 'l':
        for (p++; p <= end && *p != 'e'; p++) {
            if (d != 0) {
                if (d % 2 == 1 && !isdigit(*p))
                    return NULL;
                else
                    d++;
            }
            if ((p = benc_validate_aux(p, end)) == NULL)
                return NULL;
        }
        if (p > end || (d != 0 && d % 2 != 1))
            return NULL;
        break;
    case 'i':
        p++;
        if (p > end)
            return NULL;
        if (*p == '-')
            p++;
        if (p > end || !isdigit(*p))
            return NULL;
        p++;
        while (p <= end && isdigit(*p))
            p++;
        if (p > end || *p != 'e')
            return NULL;
        break;
    default:
        if (isdigit(*p)) {
            size_t len = 0;
            while (p <= end && isdigit(*p)) {
                len *= 10;
                len += *p - '0';
                p++;
            }
            if (p <= end && *p == ':' && p + len <= end)
                p += len;
            else
                return NULL;
        }
        else
            return NULL;
        break;
    }
    return p;
}

size_t
benc_length(const char *p)
{
    size_t blen;
    const char *next;

    switch (*p) {
    case 'd':
    case 'l':
        blen = 2; // [l|d]...e
        next = benc_first(p);
        while (*next != 'e') {
            size_t len = benc_length(next);
            blen += len;
            next += len;
        }
        return blen;
    case 'i':
        for (next = p + 1; *next != 'e'; next++)
            ;
        return next - p + 1;
    default:
        assert(benc_str(p, &next, &blen, NULL) == 0);
        return next - p + blen;
    }
}

size_t
benc_nelems(const char *p)
{
    size_t nelems = 0;
    for (p = benc_first(p); p != NULL; p = benc_next(p))
        nelems++;
    return nelems;
}

const char *
benc_first(const char *p)
{
    assert(benc_islst(p));
    return *(p + 1) == 'e' ? NULL : p + 1;
}

const char *
benc_next(const char *p)
{
    size_t blen = benc_length(p);
    return *(p + blen) == 'e' ? NULL : p + blen;
}

int
benc_str(const char *p, const char **out, size_t *len, const char**next)
{
    size_t blen = 0;
    assert(isdigit(*p));
    blen = *p - '0';
    p++;
    while (isdigit(*p)) {
        blen *= 10;
        blen += *p - '0';
        p++;
    }
    assert(*p == ':');
    benc_safeset(len, blen);
    benc_safeset(out, p + 1);
    benc_safeset(next, *(p + blen + 1) == 'e' ? NULL : p + blen + 1);
    return 0;
}

int
benc_strz(const char *p, char **out, size_t *len, const char **next)
{
    int err;
    size_t blen;
    const char *bstr;

    if ((err = benc_str(p, &bstr, &blen, next)) == 0) {
        if ((*out = malloc(blen + 1)) != NULL) {
            memcpy(*out, bstr, blen);
            (*out)[blen] = '\0';
            benc_safeset(len, blen);
        } else
            err = ENOMEM;
    }
    return err;
}

int
benc_stra(const char *p, char **out, size_t *len, const char **next)
{
    int err;
    size_t blen;
    const char *bstr;

    if ((err = benc_str(p, &bstr, &blen, next)) == 0) {
        if ((*out = malloc(blen)) != NULL) {
            memcpy(*out, bstr, blen);
            benc_safeset(len, blen);
        } else
            err = ENOMEM;
    }
    return err;
}

int
benc_int64(const char *p, int64_t *out, const char **next)
{
    int sign = 1;
    int64_t res = 0;

    assert(*p == 'i');
    p++;
    if (*p == '-') {
        sign = -1;
        p++;
    }
    assert(isdigit(*p));
    res += sign * (*p - '0');
    p++;
    while (isdigit(*p)) {
        res *= sign * 10;
        res += sign * (*p - '0');
        p++;
    }
    assert(*p == 'e');
    benc_safeset(out, res);
    benc_safeset(next, *(p + 1) == 'e' ? NULL : p + 1);

    return 0;
}

int
benc_uint32(const char *p, uint32_t *out, const char **next)
{
    int err;
    int64_t res;
    if ((err = benc_int64(p, &res, next)) == 0) {
        if (res >= 0 && res <= 0xffffffffUL)
            *out = (uint32_t)res;
        else
            err = EINVAL;
    }
    return err;
}

int
benc_dget_any(const char *p, const char *key, const char **val)
{
    int res;
    size_t len, blen;
    const char *bstr;

    assert(benc_isdct(p));

    len = strlen(key);

    p = benc_first(p);
    while (p != NULL) {
        if ((res = benc_str(p, &bstr, &blen, &p)) != 0)
            return res;

        res = strncmp(bstr, key, blen);
        if (res == 0 && len == blen) {
            *val = p;
            return 0;
        } else if (res <= 0) {
            p = benc_next(p);
        } else
            return ENOENT;
    }
    return ENOENT;
}

int
benc_dget_lst(const char *p, const char *key, const char **val)
{
    int err;
    if ((err = benc_dget_any(p, key, val)) == 0)
        if (!benc_islst(*val))
            err = EINVAL;
    return err;
}

int
benc_dget_dct(const char *p, const char *key, const char **val)
{
    int err;
    if ((err = benc_dget_any(p, key, val)) == 0)
        if (!benc_isdct(*val))
            err = EINVAL;
    return err;
}

int
benc_dget_str(const char *p, const char *key, const char **val, size_t *len)
{
    int err;
    const char *sp;
    if ((err = benc_dget_any(p, key, &sp)) == 0)
        err = benc_isstr(sp) ? benc_str(sp, val, len, NULL) : EINVAL;
    return err;
}

int
benc_dget_stra(const char *p, const char *key, char **val, size_t *len)
{
    int err;
    const char *sp;
    if ((err = benc_dget_any(p, key, &sp)) == 0)
        err = benc_isstr(sp) ? benc_stra(sp, val, len, NULL) : EINVAL;
    return err;
}

int
benc_dget_strz(const char *p, const char *key, char **val, size_t *len)
{
    int err;
    const char *sp;
    if ((err = benc_dget_any(p, key, &sp)) == 0)
        err = benc_isstr(sp) ? benc_strz(sp, val, len, NULL) : EINVAL;
    return err;
}

int
benc_dget_int64(const char *p, const char *key, int64_t *val)
{
    int err;
    const char *ip;
    if ((err = benc_dget_any(p, key, &ip)) == 0)
        err = benc_isint(ip) ? benc_int64(ip, val, NULL) : EINVAL;
    return err;
}

int
benc_dget_uint32(const char *p, const char *key, uint32_t *val)
{
    int err;
    const char *ip;
    if ((err = benc_dget_any(p, key, &ip)) == 0)
        err = benc_isint(ip) ? benc_uint32(ip, val, NULL) : EINVAL;
    return err;
}

int
benc_islst(const char *p)
{
    return *p == 'l' || *p == 'd';
}

int
benc_isdct(const char *p)
{
    return *p == 'd';
}

int
benc_isint(const char *p)
{
    return *p == 'i';
}

int
benc_isstr(const char *p)
{
    return isdigit(*p);
}