#include "config.h"
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/poll.h>
#include <inttypes.h>
#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <stdarg.h>
#include <errno.h>

#include "lanbd_whohas.h"
#include "lanbd_write.h"
#include "lanbd_read.h"
#include "dendian.h"
#include "lanbd.h"
#include "net.h"


/*
 * SYNOPSIS
 *    void *cmd_val2func(
 *            cmd_t cmd
 *          );
 *
 * ARGUMENTS
 *    cmd_t cmd              `cmd' specifies the name of the command to
 *                           lookup.  It is of type `cmd_t' which is 
 *                           an enumerated list consisting of at least:
 *                               WHOHAS
 *                               WHOHAS_R
 *                               READ
 *                               READ_R
 *                               WRITE
 *                               WRITE_R
 *
 * DESCRIPTION
 *    This function looks up the address of the function that corresponds to
 *    the specified `cmd'.
 *
 * RETURN VALUE
 *    This function returns a pointer to the appropriate function for the
 *    specified `cmd' variable argument.  If nothing appropriate is found,
 *    then NULL is returned.
 *
 */
void *cmd_val2func(cmd_t cmd) {
	/* Stays NULL unless `cmd' names a command we have a handler for. */
	void *handler=NULL;

	/* Map each known command value onto its handler function. */
	switch (cmd) {
		case WHOHAS:
			handler=cmd_whohas;
			break;
		case WHOHAS_R:
			handler=cmd_whohas_reply;
			break;
		case READ:
			handler=cmd_read;
			break;
		case READ_R:
			handler=cmd_read_reply;
			break;
		case WRITE:
			handler=cmd_write;
			break;
		case WRITE_R:
			handler=cmd_write_reply;
			break;
		default:
			/* Unknown command: report "no handler" to the caller. */
			break;
	}

	return(handler);
}

/*
 * SYNOPSIS
 *    ssize_t lanbd_senddata(
 *            struct sockaddr *dest, socklen_t destlen,
 *            void *buf, size_t buflen, int async,
 *            cmd_t cmd, uint32_t device, uint64_t block, ...
 *          );
 *
 * ARGUMENTS
 *    struct sockaddr *dest  Address the packet is sent to.
 *    socklen_t destlen      Length of `dest'.
 *    void *buf              For READ_R and WRITE this is the payload that is
 *                           copied into the packet.  For synchronous
 *                           commands it is also handed to the reply handler
 *                           together with `buflen'.
 *    size_t buflen          Size of `buf'; must not exceed LANBD_PACKETSIZE.
 *    int async              Only consulted for WRITE: non-zero sends the
 *                           write without waiting for a reply.
 *    cmd_t cmd              Command to send.
 *    uint32_t device        Device number placed in the packet header.
 *    uint64_t block         Block number placed in the packet header.
 *    ...                    Command specific arguments:
 *                               WHOHAS      int b_flags
 *                               WHOHAS_R    uint64_t e_block, int b_flags
 *                               READ        uint32_t offset
 *                               READ_R      uint32_t offset
 *                               WRITE       uint32_t offset
 *                               WRITE_R     uint32_t offset, int stat
 *
 * DESCRIPTION
 *    Builds a packet consisting of the common header (cmd, rport, device,
 *    block) followed by the command specific fields, and sends it to `dest'
 *    over a UDP socket obtained from net_listen_udp().  WHOHAS_R, READ_R,
 *    WRITE_R (and WRITE when `async' is set) are sent with a reply port of
 *    0 and return as soon as the packet is sent.  For every other command
 *    the function waits up to two seconds for one packet on the reply
 *    socket and passes it to the handler returned by cmd_val2func().
 *
 *    When _REENTRANT is defined the socket is local to the call and closed
 *    before returning; otherwise the socket and port are kept in static
 *    storage and reused by later calls.
 *
 * RETURN VALUE
 *    -EFBIG     `buflen' is larger than LANBD_PACKETSIZE.
 *    -EIO       `cmd' is not a known command, no listening socket could be
 *               created, recv() failed, or the reply was shorter than the
 *               common header.
 *    -EFAULT    A reply is expected but `buf' is NULL.
 *    -ENOSPC    The reply buffer could not be allocated.
 *    -1         No reply arrived (poll() timeout or error) or no handler
 *               exists for the reply's command.
 *    Otherwise the return value of net_send_udp() (asynchronous commands or
 *    a failed send) or of the reply handler (synchronous commands).
 *
 */
ssize_t lanbd_senddata(struct sockaddr *dest, socklen_t destlen, void *buf, size_t buflen, int async, cmd_t cmd, uint32_t device, uint64_t block, ...) {
	struct lanbd_packet p_data; /* Zeroed below, before first use. */
	struct sockaddr_in *dest_in=NULL; /* XXX: TEMPORARY */
	struct pollfd pfds[1];
	int (*cmdfunc)()=NULL;
	ssize_t retval=0;
	ssize_t recvlen=-1;
	char *msgbuf=NULL, *msgbuf_s=NULL;
	size_t msgbuflen=0;
	va_list va_argparms; /* Uninitialized */
	/* Size of the header common to every packet: cmd, rport, device, block. */
	const size_t hdrlen=sizeof(uint8_t)+sizeof(p_data.rport)+sizeof(p_data.device)+sizeof(p_data.block);
	uint32_t p_buflen=(uint32_t) hdrlen;
	char *p_buf=NULL, p_buf_r[LANBD_PACKETSIZE+sizeof(struct lanbd_packet)];
#ifdef _REENTRANT
	uint16_t rport=0;
	int sockfd=-1;
#else
	/* If we're not threaded, keep the same socket and port. */
	static uint16_t rport=0;
	static int sockfd=-1;
#endif
	int cmd_valid=1;
	int pollret=-1;

	DBG_ENTER("%p, %i, %p, %i, %i, %i, 0x%08x, 0x%016llx, ...", (void *) dest, (int) destlen, buf, (int) buflen, async, cmd, device, (unsigned long long) block);

	/* Make sure the buffer will fit in a packet, if not, complain. */
	if (buflen>LANBD_PACKETSIZE) {
		return(-EFBIG);
	}

	/*
	 * Start from a fully defined packet so that neither the rport
	 * placeholder written below nor the fields later seen by the reply
	 * handler are read uninitialized.
	 */
	memset(&p_data, 0, sizeof(p_data));

	p_buf=p_buf_r;

	/* Get the common parameters (except for rport, which is figured out below) */
	p_data.addr=dest;
	p_data.addrlen=destlen;
	p_data.cmd=cmd;
	p_data.device=device;
	p_data.block=block;
	p_data.data=buf;
	p_data.datalen=buflen;
	p_data.len=buflen;

	/* Fill in the buffer */
	de_writebuf(p_buf, p_data.cmd, sizeof(uint8_t)); p_buf+=sizeof(uint8_t);
	de_writebuf(p_buf, p_data.rport, sizeof(p_data.rport)); p_buf+=sizeof(p_data.rport); /* Placeholder (0), the real rport is filled in below. */
	de_writebuf(p_buf, p_data.device, sizeof(p_data.device)); p_buf+=sizeof(p_data.device);
	de_writebuf(p_buf, p_data.block, sizeof(p_data.block)); p_buf+=sizeof(p_data.block);

	/* Figure out the command specific parameters */
	va_start(va_argparms, block);
	switch (cmd) {
		case WHOHAS:
			p_data.b_flags = va_arg(va_argparms, int);
			p_data.flags=PFL_NONE;
			p_buflen+=sizeof(uint8_t); /* b_flags */
			de_writebuf(p_buf, p_data.b_flags, sizeof(uint8_t)); p_buf+=sizeof(uint8_t);
			break;
		case WHOHAS_R:
			p_data.e_block = va_arg(va_argparms, uint64_t);
			p_data.b_flags = va_arg(va_argparms, int);
			p_data.flags=PFL_ASYNC;
			p_buflen+=sizeof(uint8_t)+sizeof(p_data.e_block); /* b_flags + e_block */
			de_writebuf(p_buf, p_data.e_block, sizeof(p_data.e_block)); p_buf+=sizeof(p_data.e_block);
			de_writebuf(p_buf, p_data.b_flags, sizeof(uint8_t)); p_buf+=sizeof(uint8_t);
			break;
		case READ:
			p_data.offset = va_arg(va_argparms, uint32_t);
			p_data.flags=PFL_NONE;
			p_buflen+=sizeof(p_data.offset)+sizeof(p_data.len);
			de_writebuf(p_buf, p_data.offset, sizeof(p_data.offset)); p_buf+=sizeof(p_data.offset);
			de_writebuf(p_buf, p_data.len, sizeof(p_data.len)); p_buf+=sizeof(p_data.len);
			break;
		case READ_R:
			p_data.offset = va_arg(va_argparms, uint32_t);
			p_data.flags=PFL_ASYNC;
			p_buflen+=sizeof(p_data.offset)+sizeof(p_data.len)+buflen;
			de_writebuf(p_buf, p_data.offset, sizeof(p_data.offset)); p_buf+=sizeof(p_data.offset);
			de_writebuf(p_buf, p_data.len, sizeof(p_data.len)); p_buf+=sizeof(p_data.len);
			memcpy(p_buf, buf, buflen);
			break;
		case WRITE:
			p_data.offset = va_arg(va_argparms, uint32_t);
			p_buflen+=sizeof(p_data.offset)+sizeof(p_data.len)+buflen;
			if (async) { p_data.flags=PFL_ASYNC; } else { p_data.flags=PFL_NONE; }
			de_writebuf(p_buf, p_data.offset, sizeof(p_data.offset)); p_buf+=sizeof(p_data.offset);
			de_writebuf(p_buf, p_data.len, sizeof(p_data.len)); p_buf+=sizeof(p_data.len);
			memcpy(p_buf, buf, buflen);
			break;
		case WRITE_R:
			p_data.offset = va_arg(va_argparms, uint32_t);
			p_data.stat = va_arg(va_argparms, int);
			p_data.flags=PFL_ASYNC;
			p_buflen+=sizeof(p_data.offset)+sizeof(p_data.len)+sizeof(p_data.stat);
			de_writebuf(p_buf, p_data.offset, sizeof(p_data.offset)); p_buf+=sizeof(p_data.offset);
			de_writebuf(p_buf, p_data.len, sizeof(p_data.len)); p_buf+=sizeof(p_data.len);
			de_writebuf(p_buf, p_data.stat, sizeof(p_data.stat)); p_buf+=sizeof(p_data.stat);
			break;
		default:
			PRINTERR("Invalid command (cmd=%i)", cmd);
			/* Defer the return so that va_end() below is always reached. */
			cmd_valid=0;
			break;
	}
	va_end(va_argparms);

	if (!cmd_valid) {
		return(-EIO);
	}

	/* Find a port to listen on to wait for replies. */
	/* We do this even in async mode so we can have a */
	/* good dport value. */
	if (sockfd<0 || rport==0) {
		PRINTERR("Finding port...");
		for (rport=LANBD_RPORT_MAX; rport>=128; rport--) {
			sockfd=net_listen_udp(rport);
			if (sockfd>=0) break;
		}
	}

	/* Without a socket we can neither send nor receive. */
	if (sockfd<0) {
		PRINTERR("Could not create listening socket for reply!");
		rport=0;
		return(-EIO);
	}

	/* Figure out the `rport' based on async. */
	if ((p_data.flags&PFL_ASYNC)==PFL_ASYNC) {
		/* This is an async call, there's no reply port, specify 0 to
		 * indicate this.
		 */
		p_data.rport=0;
	} else {
		p_data.rport=rport;
	}

	/* Fill the `rport' data in (it directly follows the command byte). */
	p_buf=p_buf_r+sizeof(uint8_t);
	de_writebuf(p_buf, p_data.rport, sizeof(p_data.rport)); p_buf+=sizeof(p_data.rport);

	/* Send the data */
	PRINTBUF(p_buf_r, p_buflen);
	retval=net_send_udp(sockfd, p_data.addr, p_data.addrlen, p_buf_r, p_buflen);
	dest_in=(struct sockaddr_in *) dest; /* XXX: TEMPORARY */
	PRINTERR("Wrote %i bytes to destination (%08x/%i).", (int) retval, dest_in->sin_addr.s_addr, ntohs(dest_in->sin_port)); /* XXX: TEMPORARY */

	/* If we're async, return the sendfunc reply. */
	if ((p_data.flags&PFL_ASYNC)==PFL_ASYNC) {
		goto out;
	}

	/* If there was an error sending, pass the reply along. */
	if (retval<0) {
		goto out;
	}

	/* If buf is invalid, don't try to write to it, return an error. */
	if (buf==NULL) {
		PRINTERR("Invalid buffer, not attempting to get a reply.");
		retval=-EFAULT;
		goto out;
	}

	/*
	 * We need to wait for a reply on our socket. XXX
	 * NOTE(review): earlier revisions computed `cmd|0x8' here as the
	 * expected reply command but never compared it with the received
	 * packet; the reply's command is still not validated.
	 */
	pfds[0].fd=sockfd;
	pfds[0].events=POLLIN;
	pfds[0].revents=0;
	PRINTERR("Listening on port %i for replies.", rport);
	pollret=poll(pfds, 1, 2*1000); /* XXX: Need to not hard-code the timeout. */
	if (pollret<=0) {
		/* Timeout or poll() failure: nothing to read. */
		retval=-1;
		goto out;
	}

	/* Allocate space for the reply (payload plus header). */
	msgbuflen=buflen+sizeof(struct lanbd_packet);
	msgbuf_s=msgbuf=malloc(msgbuflen);
	if (msgbuf==NULL) {
		PRINTERR("Error allocating data buffer!  Aborting.");
		retval=-ENOSPC;
		goto out;
	}

	/*
	 * Receive data.  The result is kept in a signed variable: storing it
	 * in the unsigned `msgbuflen' would make a failed recv() look like an
	 * enormous packet.
	 */
	recvlen=recv(sockfd, msgbuf, msgbuflen, 0);
	if (recvlen<0) {
		PERROR("recv");
		PRINTERR("Error recieving data!  Aborting.");
		retval=-EIO;
		goto out_free;
	}

	/* A reply that does not even hold the common header cannot be parsed. */
	if ((size_t) recvlen<hdrlen) {
		PRINTERR("Reply too short to contain a header!  Aborting.");
		retval=-EIO;
		goto out_free;
	}
	msgbuflen=(size_t) recvlen;

	/* Process received packet. */
	/* Get common header information from the packet. */
	p_data.cmd=de_readbuf(msgbuf, sizeof(uint8_t)); msgbuf+=sizeof(uint8_t); msgbuflen-=sizeof(uint8_t);
	p_data.rport=de_readbuf(msgbuf, sizeof(p_data.rport)); msgbuf+=sizeof(p_data.rport); msgbuflen-=sizeof(p_data.rport);
	p_data.device=de_readbuf(msgbuf, sizeof(p_data.device)); msgbuf+=sizeof(p_data.device); msgbuflen-=sizeof(p_data.device);
	p_data.block=de_readbuf(msgbuf, sizeof(p_data.block)); msgbuf+=sizeof(p_data.block); msgbuflen-=sizeof(p_data.block);

	/* Whatever follows the header is handed to the command handler. */
	p_data.data=msgbuf;
	p_data.datalen=msgbuflen;
	SPOTVAR_I(msgbuflen);
	cmdfunc=cmd_val2func(p_data.cmd);
	if (cmdfunc) {
		retval=cmdfunc(&p_data, buf, buflen);
	} else {
		retval=-1;
	}

out_free:
	free(msgbuf_s);

out:
#ifdef _REENTRANT
	/* Threaded builds use a per-call socket; never leave it open. */
	if (sockfd>=0) {
		PRINTERR("Forcing reentrant");
		close(sockfd);
	}
	sockfd=-1;
#endif
	return(retval);
}

/*
 * SYNOPSIS
 *    static struct sockaddr *lanbd_whohas(
 *            lanbd_packet_bfl_t type, uint32_t device, uint64_t block
 *          );
 *
 * DESCRIPTION
 *    Determines the address that requests for `block' on `device' should
 *    be sent to.
 *
 *    With LANBDD defined, a WHOHAS packet carrying `type' is sent to the
 *    broadcast address on LANBD_PORT via lanbd_senddata(), with a freshly
 *    allocated `struct sockaddr' passed as the reply buffer.
 *    NOTE(review): presumably the WHOHAS reply handler stores the
 *    responder's address in that buffer -- confirm against
 *    cmd_whohas_reply().
 *
 *    Without LANBDD, the loopback address on LANBD_PORT is returned.
 *
 * RETURN VALUE
 *    With LANBDD: a malloc()ed address the caller must free(), or NULL if
 *    allocation or the WHOHAS exchange failed.
 *    Without LANBDD: a pointer to static storage that must NOT be freed and
 *    is overwritten by the next call.
 *
 */
static struct sockaddr *lanbd_whohas(lanbd_packet_bfl_t type, uint32_t device, uint64_t block) {
#ifdef LANBDD
	struct sockaddr *dest=NULL;
	/* The broadcast target is only needed for the duration of the call, so
	 * it lives on the stack rather than on the heap. */
	struct sockaddr_in dest_in;
	ssize_t lanbd_ret=0;
	socklen_t dest_inlen=sizeof(dest_in);
	size_t destlen=sizeof(*dest);

	DBG_ENTER("%i, 0x%08x, 0x%016llx", type, device, (unsigned long long) block);

	/* Buffer handed to lanbd_senddata() as the reply buffer and returned
	 * to the caller on success. */
	dest=malloc(destlen);
	if (dest==NULL) return(NULL);

	memset(&dest_in, 0, sizeof(dest_in));
	dest_in.sin_family=AF_INET;
	dest_in.sin_port=htons(LANBD_PORT); /* XXX: need to unhardcode this. */
	dest_in.sin_addr.s_addr=htonl(INADDR_BROADCAST);

	lanbd_ret=lanbd_senddata((struct sockaddr *) &dest_in, dest_inlen, dest, destlen, 0, WHOHAS, device, block, type);
	if (lanbd_ret<0) {
		free(dest);
		return(NULL);
	}
	return(dest);
#else
	static struct sockaddr_in dest; /* Uninitialized */

	DBG_ENTER("%i, 0x%08x, 0x%016llx", type, device, (unsigned long long) block);

	dest.sin_family=AF_INET;
	dest.sin_addr.s_addr=htonl(INADDR_LOOPBACK);
	dest.sin_port=htons(LANBD_PORT); /* XXX: need to unhardcode this... */
	return((struct sockaddr *) &dest);
#endif
}

/*
 * SYNOPSIS
 *    ssize_t lanbd_read(
 *            uint32_t device, uint64_t block, uint32_t offset,
 *            void *buf, size_t count
 *          );
 *
 * DESCRIPTION
 *    Looks up a destination for `block' on `device' with
 *    lanbd_whohas(BFL_ANY, ...) and issues a synchronous READ for `count'
 *    bytes at `offset' to it, passing `buf' as the reply buffer.
 *
 * RETURN VALUE
 *    -EFAULT if `buf' is NULL, -ENOSPC if no destination could be
 *    determined, otherwise whatever lanbd_senddata() returns.
 *
 */
ssize_t lanbd_read(uint32_t device, uint64_t block, uint32_t offset, void *buf, size_t count) {
	struct sockaddr *dest=NULL;
	ssize_t readret=-1;

	DBG_ENTER("0x%08x, 0x%016llx, 0x%08x, %p, %i", device, (unsigned long long) block, offset, buf, (int) count);

	if (buf==NULL) return(-EFAULT);

	dest=lanbd_whohas(BFL_ANY, device, block);
	if (dest==NULL) return(-ENOSPC);

	readret=lanbd_senddata(dest, sizeof(*dest), buf, count, 0, READ, device, block, offset);

#ifdef LANBDD
	/* In LANBDD builds lanbd_whohas() returns a malloc()ed address that we
	 * own; in other builds it is static storage and must not be freed. */
	free(dest);
#endif

	return(readret);
}

/*
 * SYNOPSIS
 *    ssize_t lanbd_write(
 *            uint32_t device, uint64_t block, uint32_t offset,
 *            void *buf, size_t count
 *          );
 *
 * DESCRIPTION
 *    Looks up a destination for `block' on `device' with
 *    lanbd_whohas(BFL_MASTER, ...) and issues a synchronous (non-async)
 *    WRITE of the `count' bytes in `buf' at `offset' to it.
 *
 * RETURN VALUE
 *    -EFAULT if `buf' is NULL, -ENOSPC if no destination could be
 *    determined, otherwise whatever lanbd_senddata() returns.
 *
 */
ssize_t lanbd_write(uint32_t device, uint64_t block, uint32_t offset, void *buf, size_t count) {
	struct sockaddr *dest=NULL;
	ssize_t writeret=-1;

	DBG_ENTER("0x%08x, 0x%016llx, 0x%08x, %p, %i", device, (unsigned long long) block, offset, buf, (int) count);

	if (buf==NULL) return(-EFAULT);

	dest=lanbd_whohas(BFL_MASTER, device, block);
	if (dest==NULL) return(-ENOSPC);

	writeret=lanbd_senddata(dest, sizeof(*dest), buf, count, 0, WRITE, device, block, offset);

#ifdef LANBDD
	/* In LANBDD builds lanbd_whohas() returns a malloc()ed address that we
	 * own; in other builds it is static storage and must not be freed. */
	free(dest);
#endif

	return(writeret);
}

