#include <assert.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <stdlib.h>
#include <unistd.h>

#include <openssl/sha.h>

#include "metainfo.h"
#include "subr.h"
#include "stream.h"

/*
 * Allocate a read-only stream over the files listed in meta and position
 * it at torrent offset off.
 *
 * fd_cb and fd_arg are only stored here; bts_read_ro() invokes the
 * callback when it needs a descriptor for a file.  No file is opened by
 * this function.
 *
 * Returns the new stream, or NULL if the allocation fails.  The stream is
 * released with bts_close_ro().
 */
struct bt_stream_ro *
bts_open_ro(struct metainfo *meta, off_t off, F_fdcb fd_cb, void *fd_arg)
{
    struct bt_stream_ro *stream;

    stream = malloc(sizeof *stream);
    if (stream != NULL) {
        stream->fd = -1;        /* nothing open yet */
        stream->index = 0;
        stream->f_off = 0;
        stream->t_off = 0;
        stream->fd_arg = fd_arg;
        stream->fd_cb = fd_cb;
        stream->meta = meta;

        /* Translate the torrent offset into (file index, file offset). */
        bts_seek_ro(stream, off);
    }
    return stream;
}

/*
 * Reposition the stream at absolute torrent offset off.
 *
 * Any descriptor the stream currently holds is closed; the file containing
 * the new position is opened lazily by the next read.  off must satisfy
 * 0 <= off <= total_length (checked with assert).
 *
 * On return t_off is the torrent offset, index is the file containing it
 * and f_off is the offset inside that file.  Seeking to exactly
 * total_length leaves the stream at end-of-torrent; index may then be one
 * past the last file, so nothing may be read from that position.
 */
void
bts_seek_ro(struct bt_stream_ro *bts, off_t off)
{
    struct fileinfo *files = bts->meta->files;
    /* Bytes of the torrent not yet stepped over by the walk below. */
    off_t unwalked = bts->meta->total_length;

    assert(off >= 0 && off <= bts->meta->total_length);

    if (bts->fd != -1) {
        close(bts->fd);
        bts->fd = -1;
    }

    bts->t_off = off;
    bts->index = 0;

    /*
     * Step over every file that ends at or before the target offset.
     * The walk stops once all bytes of the torrent have been accounted
     * for, so a seek to total_length does not read a fileinfo entry past
     * the end of the array.  For off < total_length the extra condition
     * is always true and the result is unchanged.
     */
    while (unwalked > 0 && off >= files[bts->index].length) {
        off -= files[bts->index].length;
        unwalked -= files[bts->index].length;
        bts->index++;
    }

    bts->f_off = off;
}

/*
 * Read len bytes from the current stream position into buf, crossing file
 * boundaries as needed.
 *
 * Files are opened on demand through the stream's fd_cb callback, which is
 * given the file path, a pointer to the stream's descriptor slot and
 * fd_arg.  A file is closed as soon as its last byte has been consumed.
 *
 * The request must not extend past total_length (checked with assert).
 *
 * Returns 0 on success.  Otherwise returns:
 *   - the non-zero value returned by fd_cb,
 *   - errno from a failed lseek() or read(),
 *   - ENOENT if a file yields fewer bytes than the metainfo says it holds.
 * After a non-zero return the stream position is not well defined for the
 * caller; use bts_seek_ro() before reading again.
 */
int
bts_read_ro(struct bt_stream_ro *bts, char *buf, size_t len)
{
    struct fileinfo *files = bts->meta->files;
    size_t boff, wantread;
    ssize_t didread;

    assert(bts->t_off + len <= bts->meta->total_length);

    boff = 0;
    while (boff < len) {
        if (bts->fd == -1) {
            int err =
                bts->fd_cb(files[bts->index].path, &bts->fd, bts->fd_arg);
            if (err != 0)
                return err;
            /*
             * A freshly opened file starts at offset 0.  If the seek to
             * the stream position fails, drop the descriptor so a later
             * call cannot silently read from the wrong offset.
             */
            if (bts->f_off != 0 &&
                lseek(bts->fd, bts->f_off, SEEK_SET) == (off_t)-1) {
                err = errno;
                close(bts->fd);
                bts->fd = -1;
                return err;
            }
        }

        /* Never read past the end of the current file. */
        wantread = min(len - boff, files[bts->index].length - bts->f_off);
        didread = read(bts->fd, buf + boff, wantread);
        if (didread == -1)
            return errno;

        boff += didread;
        bts->f_off += didread;
        bts->t_off += didread;
        if (bts->f_off == files[bts->index].length) {
            /* Current file exhausted: move on to the next one. */
            close(bts->fd);
            bts->fd = -1;
            bts->f_off = 0;
            bts->index++;
        }
        /* A short read means the file is smaller than expected. */
        if (didread != wantread)
            return ENOENT;
    }
    return 0;
}

/*
 * Destroy a read-only stream: close the descriptor it still holds, if
 * any, then free the stream itself.  bts must not be used afterwards.
 */
void
bts_close_ro(struct bt_stream_ro *bts)
{
    const int open_fd = bts->fd;

    if (open_fd != -1) {
        close(open_fd);
    }
    free(bts);
}

/* Size in bytes of the stack buffer bts_sha() reads through (32 KiB). */
#define SHAFILEBUF (1 << 15)

/*
 * Compute the SHA-1 digest of the next length bytes of the stream and
 * store it in hash.
 *
 * The data is pulled through bts_read_ro() in chunks of at most
 * SHAFILEBUF bytes.  If a read fails, hashing stops and that error is
 * returned; hash is still written, but then only covers the chunks that
 * were read successfully.  Returns 0 on success.
 */
int
bts_sha(struct bt_stream_ro *bts, off_t length, uint8_t *hash)
{
    char chunk[SHAFILEBUF];
    SHA_CTX sha;
    size_t n;
    off_t left;
    int rv = 0;

    SHA1_Init(&sha);
    for (left = length; left > 0; left -= n) {
        n = min(left, SHAFILEBUF);
        rv = bts_read_ro(bts, chunk, n);
        if (rv != 0)
            break;
        SHA1_Update(&sha, chunk, n);
    }
    SHA1_Final(hash, &sha);
    return rv;
}

/*
 * Hash every piece of the torrent described by meta and report each
 * result through cb.
 *
 * A read-only stream is opened at offset 0 with fd_cb as the file-open
 * callback; arg is handed both to that stream (as its fd_arg) and to cb.
 *
 * For each piece index, cb is called as:
 *   cb(piece, hash, arg)  - the piece was read and hash holds its SHA-1
 *   cb(piece, NULL, arg)  - reading the piece failed with ENOENT
 * After an ENOENT the stream is repositioned at the start of the next
 * piece and hashing continues.
 *
 * Returns 0 when all pieces were processed, ENOMEM if the stream could not
 * be allocated, or the first error other than ENOENT, which stops the walk.
 */
int
bts_hashes(struct metainfo *meta,
    F_fdcb fd_cb,
    void (*cb)(uint32_t, uint8_t *, void *),
    void *arg)
{
    int err = 0;
    uint8_t hash[SHA_DIGEST_LENGTH];
    uint32_t piece;
    struct bt_stream_ro *bts;
    off_t plen = meta->piece_length;

    if ((bts = bts_open_ro(meta, 0, fd_cb, arg)) == NULL)
        return ENOMEM;

    for (piece = 0; piece < meta->npieces; piece++) {
        int is_last = (piece == meta->npieces - 1);
        off_t want = plen;

        /*
         * The last piece holds whatever remains after the preceding
         * full-size pieces.  Using total_length % piece_length here would
         * give 0 when the total is an exact multiple of the piece length,
         * and the final piece would be hashed as empty.
         */
        if (is_last)
            want = meta->total_length - (off_t)piece * plen;

        err = bts_sha(bts, want, hash);
        if (err == 0)
            cb(piece, hash, arg);
        else if (err == ENOENT) {
            cb(piece, NULL, arg);
            /* The failed read left the stream mid-piece; realign it. */
            if (!is_last)
                bts_seek_ro(bts, (off_t)(piece + 1) * plen);
            err = 0;
        } else
            break;
    }
    bts_close_ro(bts);
    return err;
}

/*
 * Allocate a write-only stream over the files listed in meta and position
 * it at torrent offset off.
 *
 * fd_cb and fd_arg are only stored here; bts_write_wo() invokes the
 * callback when it needs a descriptor for a file.
 *
 * Returns the new stream, or NULL if the allocation fails.  The stream is
 * released with bts_close_wo().
 */
struct bt_stream_wo *
bts_open_wo(struct metainfo *meta, off_t off, F_fdcb fd_cb, void *fd_arg)
{
    struct bt_stream_wo *out = malloc(sizeof *out);

    if (out == NULL) {
        return NULL;
    }

    out->meta = meta;
    out->fd_arg = fd_arg;
    out->fd_cb = fd_cb;
    out->fd = -1;               /* nothing open yet */
    out->index = 0;
    out->t_off = 0;
    out->f_off = 0;

    /*
     * Positioning is delegated to the read-only seek routine.
     * NOTE(review): this assumes struct bt_stream_wo and struct
     * bt_stream_ro lay out the fields used by bts_seek_ro() identically;
     * confirm against stream.h.
     */
    bts_seek_ro((struct bt_stream_ro *)out, off);
    return out;
}

/*
 * Write len bytes from buf at the current stream position, crossing file
 * boundaries as needed.
 *
 * Files are opened on demand through the stream's fd_cb callback, which is
 * given the file path, a pointer to the stream's descriptor slot and
 * fd_arg.  When the last byte of a file has been written, the file is
 * fsync()ed and closed before the stream moves on to the next file.
 * Short writes are continued until the whole buffer has been written.
 *
 * The request must not extend past total_length (checked with assert).
 *
 * Returns 0 on success.  Otherwise returns the non-zero value from fd_cb,
 * or errno from a failed lseek(), write(), fsync() or close().  After an
 * fsync() or close() failure the stream no longer holds a descriptor, so
 * bts_close_wo() will not touch the already closed file.
 */
int
bts_write_wo(struct bt_stream_wo *bts, const char *buf, size_t len)
{
    struct fileinfo *files = bts->meta->files;
    size_t boff, wantwrite;
    ssize_t didwrite;

    assert(bts->t_off + len <= bts->meta->total_length);

    boff = 0;
    while (boff < len) {
        if (bts->fd == -1) {
            int err =
                bts->fd_cb(files[bts->index].path, &bts->fd, bts->fd_arg);
            if (err != 0)
                return err;
            /*
             * A freshly opened file starts at offset 0.  If the seek to
             * the stream position fails, drop the descriptor rather than
             * write the data at the wrong place.
             */
            if (bts->f_off != 0 &&
                lseek(bts->fd, bts->f_off, SEEK_SET) == (off_t)-1) {
                err = errno;
                close(bts->fd);
                bts->fd = -1;
                return err;
            }
        }

        /* Never write past the end of the current file. */
        wantwrite = min(len - boff, files[bts->index].length - bts->f_off);
        didwrite = write(bts->fd, buf + boff, wantwrite);
        if (didwrite == -1)
            return errno;

        boff += didwrite;
        bts->f_off += didwrite;
        bts->t_off += didwrite;
        if (bts->f_off == files[bts->index].length) {
            int fd = bts->fd;

            /*
             * The descriptor is closed on every path below, so forget it
             * first.  Leaving the old value in bts->fd on an error return
             * would make bts_close_wo() fsync and close a descriptor that
             * is no longer ours.
             */
            bts->fd = -1;
            if (fsync(fd) == -1) {
                int err = errno;
                close(fd);
                return err;
            }
            if (close(fd) == -1)
                return errno;
            bts->f_off = 0;
            bts->index++;
        }
    }
    return 0;
}

/*
 * Destroy a write-only stream.
 *
 * If a file is still open it is fsync()ed and then closed.  The stream is
 * freed in every case and must not be used afterwards.
 *
 * Returns 0 on success.  If fsync() fails its errno is returned (the file
 * is still closed); otherwise errno from a failed close() is returned.
 */
int
bts_close_wo(struct bt_stream_wo *bts)
{
    const int fd = bts->fd;
    int rv = 0;

    if (fd != -1) {
        const int synced = fsync(fd);

        /* Capture the fsync error before close() can overwrite errno. */
        if (synced == -1)
            rv = errno;

        /* A close() error is reported only when fsync() succeeded. */
        if (close(fd) == -1 && synced != -1)
            rv = errno;
    }
    free(bts);
    return rv;
}