#include "video_encoder.h"

#include <assert.h>
#include <stdio.h>
#include <time.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netdb.h>
#include <string>
#include <thread>
#include <chrono>

extern "C" {
#include <libavutil/mem.h>
}

#include "audio_encoder.h"
#ifdef HAVE_AV1
#include "av1_encoder.h"
#endif
#include "defs.h"
#include "shared/ffmpeg_raii.h"
#include "flags.h"
#include "shared/httpd.h"
#include "shared/mux.h"
#include "quicksync_encoder.h"
#include "shared/timebase.h"
#include "x264_encoder.h"

class RefCountedFrame;

using namespace std;
using namespace std::chrono;
using namespace movit;

namespace {

// Builds the path of a new local recording file:
//
//   <recording_dir>/<LOCAL_DUMP_PREFIX><local timestamp>-f<frame % 100><LOCAL_DUMP_SUFFIX>
//
// The timestamp is the current local time formatted as "%F-%H%M%S%z".
// The path is assembled in a std::string, so an arbitrarily long recording
// directory is never silently truncated.
string generate_local_dump_filename(int frame)
{
	time_t now = time(nullptr);
	tm now_tm;
	localtime_r(&now, &now_tm);

	char timestamp[64];
	if (strftime(timestamp, sizeof(timestamp), "%F-%H%M%S%z", &now_tm) == 0) {
		// On failure the buffer contents are unspecified; fall back to
		// an empty timestamp rather than reading garbage.
		timestamp[0] = '\0';
	}

	// Use the frame number to disambiguate between two cuts starting
	// on the same second.
	char frame_tag[32];
	snprintf(frame_tag, sizeof(frame_tag), "-f%02d", frame % 100);

	string filename = global_flags.recording_dir.c_str();
	filename += '/';
	filename += LOCAL_DUMP_PREFIX;
	filename += timestamp;
	filename += frame_tag;
	filename += LOCAL_DUMP_SUFFIX;
	return filename;
}

}  // namespace

// Builds the whole encoding chain: looks up the output container formats,
// creates the audio and video encoders selected by global_flags, opens the
// first local recording, and finally attaches every encoder to the HTTP mux
// (and to the SRT mux, if open_output_streams() created one).
VideoEncoder::VideoEncoder(ResourcePool *resource_pool, QSurface *surface, const std::string &va_display, int width, int height, HTTPD *httpd, DiskSpaceEstimator *disk_space_estimator)
	: resource_pool(resource_pool), surface(surface), va_display(va_display), width(width), height(height), httpd(httpd), disk_space_estimator(disk_space_estimator)
{
	// TODO: If we're outputting AV1, we can't use MPEG-TS currently.
	srt_oformat = av_guess_format("mpegts", nullptr, nullptr);
	assert(srt_oformat != nullptr);

	oformat = av_guess_format(global_flags.stream_mux_name.c_str(), nullptr, nullptr);
	assert(oformat != nullptr);

	// An explicitly configured stream audio codec takes precedence over
	// the compiled-in default.
	if (!global_flags.stream_audio_codec_name.empty()) {
		stream_audio_encoder.reset(new AudioEncoder(global_flags.stream_audio_codec_name, global_flags.stream_audio_codec_bitrate, oformat));
	} else {
		stream_audio_encoder.reset(new AudioEncoder(AUDIO_OUTPUT_CODEC_NAME, DEFAULT_AUDIO_OUTPUT_BIT_RATE, oformat));
	}

	const bool need_x264 = global_flags.x264_video_to_http || global_flags.x264_video_to_disk;
	if (need_x264) {
		x264_encoder.reset(new X264Encoder(oformat, /*use_separate_disk_params=*/false));
	}

	// Both outputs start out on the shared x264 encoder (null if it was
	// not created above); AV1 and the separate disk encode override one each.
	VideoCodecInterface *stream_codec = x264_encoder.get();
	VideoCodecInterface *disk_codec = stream_codec;
#ifdef HAVE_AV1
	if (global_flags.av1_video_to_http) {
		av1_encoder.reset(new AV1Encoder(oformat));
		stream_codec = av1_encoder.get();
	}
#endif
	if (global_flags.x264_separate_disk_encode) {
		x264_disk_encoder.reset(new X264Encoder(oformat, /*use_separate_disk_params=*/true));
		disk_codec = x264_disk_encoder.get();
	}

	const string first_filename = generate_local_dump_filename(/*frame=*/0);
	quicksync_encoder.reset(new QuickSyncEncoder(first_filename, resource_pool, surface, va_display, width, height, oformat, stream_codec, disk_codec, disk_space_estimator));

	// The muxes need the encoders created above, so they are opened last
	// and then wired up.
	open_output_streams();
	Mux *const http_out = http_mux.get();
	Mux *const srt_out = srt_mux.get();  // Null when no SRT destination is configured.

	stream_audio_encoder->add_mux(http_out);
	if (srt_out != nullptr) {
		stream_audio_encoder->add_mux(srt_out);
	}

	quicksync_encoder->set_http_mux(http_out);
	if (srt_out != nullptr) {
		quicksync_encoder->set_srt_mux(srt_out);
	}

	if (global_flags.x264_video_to_http) {
		x264_encoder->add_mux(http_out);
		if (srt_out != nullptr) {
			x264_encoder->add_mux(srt_out);
		}
	}
#ifdef HAVE_AV1
	if (global_flags.av1_video_to_http) {
		av1_encoder->add_mux(http_out);
		if (srt_out != nullptr) {
			av1_encoder->add_mux(srt_out);
		}
	}
#endif
}

// Tears the pipeline down: signals should_quit, shuts down the current
// QuickSync encoder, destroys the x264 encoders, closes the recording file
// and then polls until quicksync_encoders_in_shutdown has dropped to zero.
VideoEncoder::~VideoEncoder()
{
	should_quit = true;

	quicksync_encoder->shutdown();
	x264_encoder.reset();
	x264_disk_encoder.reset();
	quicksync_encoder->close_file();
	quicksync_encoder.reset();

	// Poll every 10 ms for outstanding shutdowns to complete.
	for (;;) {
		if (quicksync_encoders_in_shutdown.load() <= 0) {
			break;
		}
		std::this_thread::sleep_for(std::chrono::microseconds(10000));
	}
}

void VideoEncoder::do_cut(int frame)
{
	string filename = generate_local_dump_filename(frame);
	printf("Starting new recording: %s\n", filename.c_str());

	// Do the shutdown of the old encoder in a separate thread, since it can
	// take some time (it needs to wait for all the frames in the queue to be
	// done encoding, for one) and we are running on the main mixer thread.
	// However, since this means both encoders could be sending packets at
	// the same time, it means pts could come out of order to the stream mux,
	// and we need to plug it until the shutdown is complete.
	http_mux->plug();
	lock(qs_mu, qs_audio_mu);
	lock_guard<mutex> lock1(qs_mu, adopt_lock), lock2(qs_audio_mu, adopt_lock);
	QuickSyncEncoder *old_encoder = quicksync_encoder.release();  // When we go C++14, we can use move capture instead.
	X264Encoder *old_x264_encoder = nullptr;
	X264Encoder *old_x264_disk_encoder = nullptr;
	if (global_flags.x264_video_to_disk) {
		old_x264_encoder = x264_encoder.release();
	}
	if (global_flags.x264_separate_disk_encode) {
		old_x264_disk_encoder = x264_disk_encoder.release();
	}
	thread([old_encoder, old_x264_encoder, old_x264_disk_encoder, this]{
		old_encoder->shutdown();
		delete old_x264_encoder;
		delete old_x264_disk_encoder;
		old_encoder->close_file();
		http_mux->unplug();

		// We cannot delete the encoder here, as this thread has no OpenGL context.
		// We'll deal with it in begin_frame().
		lock_guard<mutex> lock(qs_mu);
		qs_needing_cleanup.emplace_back(old_encoder);
	}).detach();

	if (global_flags.x264_video_to_disk) {
		x264_encoder.reset(new X264Encoder(oformat, /*use_separate_disk_params=*/false));
		assert(global_flags.x264_video_to_http);
		if (global_flags.x264_video_to_http) {
			x264_encoder->add_mux(http_mux.get());
		}
		if (overriding_bitrate != 0) {
			x264_encoder->change_bitrate(overriding_bitrate);
		}
	}
	X264Encoder *http_encoder = x264_encoder.get();
	X264Encoder *disk_encoder = x264_encoder.get();
	if (global_flags.x264_separate_disk_encode) {
		x264_disk_encoder.reset(new X264Encoder(oformat, /*use_separate_disk_params=*/true));
		disk_encoder = x264_disk_encoder.get();
	}

	quicksync_encoder.reset(new QuickSyncEncoder(filename, resource_pool, surface, va_display, width, height, oformat, http_encoder, disk_encoder, disk_space_estimator));
	quicksync_encoder->set_http_mux(http_mux.get());
}

// Overrides the x264 bitrate (rate_kbit is passed through unchanged to
// X264Encoder::change_bitrate()). The value is remembered in
// overriding_bitrate so that do_cut() can reapply it to a newly created
// encoder. If no x264 encoder exists (it is only created when an x264 output
// is enabled), only the remembered value is updated.
void VideoEncoder::change_x264_bitrate(unsigned rate_kbit)
{
	overriding_bitrate = rate_kbit;
	if (x264_encoder != nullptr) {
		x264_encoder->change_bitrate(rate_kbit);
	}
}

void VideoEncoder::add_audio(int64_t pts, std::vector<float> audio)
{
	// Take only qs_audio_mu, since add_audio() is thread safe
	// (we can only conflict with do_cut(), which takes qs_audio_mu)
	// and we don't want to contend with begin_frame().
	{
		lock_guard<mutex> lock(qs_audio_mu);
		quicksync_encoder->add_audio(pts, audio);
	}
	stream_audio_encoder->encode_audio(audio, pts + quicksync_encoder->global_delay());
}

// Reports whether the current QuickSync encoder operates in zerocopy mode,
// by forwarding to QuickSyncEncoder::is_zerocopy().
bool VideoEncoder::is_zerocopy() const
{
	// qs_mu is deliberately left alone here: the mixer calls this, and
	// the mutex might be contended. The encoder's is_zerocopy() is thread
	// safe, and since do_cut() also runs only on the mixer thread, the two
	// can never execute in parallel.
	const bool zerocopy = quicksync_encoder->is_zerocopy();
	return zerocopy;
}

bool VideoEncoder::begin_frame(int64_t pts, int64_t duration, movit::YCbCrLumaCoefficients ycbcr_coefficients, const std::vector<RefCountedFrame> &input_frames, GLuint *y_tex, GLuint *cbcr_tex)
{
	lock_guard<mutex> lock(qs_mu);
	qs_needing_cleanup.clear();  // Since we have an OpenGL context here, and are called regularly.
	return quicksync_encoder->begin_frame(pts, duration, ycbcr_coefficients, input_frames, y_tex, cbcr_tex);
}

// Finishes the frame started by begin_frame() and returns the sync object
// handed back by the QuickSync encoder. Also flags that write_srt_packet()
// should refresh the SRT statistics on its next call.
RefCountedGLsync VideoEncoder::end_frame()
{
	want_srt_metric_update = true;

	lock_guard<mutex> guard(qs_mu);
	RefCountedGLsync fence = quicksync_encoder->end_frame();
	return fence;
}

void VideoEncoder::open_output_streams()
{
	for (bool is_srt : {false, true}) {
		if (is_srt && global_flags.srt_destination_host.empty()) {
			continue;
		}

		AVFormatContext *avctx = avformat_alloc_context();
		avctx->oformat = is_srt ? srt_oformat : oformat;

		uint8_t *buf = (uint8_t *)av_malloc(MUX_BUFFER_SIZE);
		avctx->pb = avio_alloc_context(buf, MUX_BUFFER_SIZE, 1, this, nullptr, nullptr, nullptr);
		if (is_srt) {
			avctx->pb->write_packet = &VideoEncoder::write_srt_packet_thunk;
		} else {
			avctx->pb->write_data_type = &VideoEncoder::write_packet2_thunk;
			avctx->pb->ignore_boundary_point = 1;
		}

		Mux::Codec video_codec;
		if (global_flags.av1_video_to_http) {
			video_codec = Mux::CODEC_AV1;
		} else {
			video_codec = Mux::CODEC_H264;
		}

		avctx->flags = AVFMT_FLAG_CUSTOM_IO;

		string video_extradata;
		if (global_flags.x264_video_to_http) {
			video_extradata = x264_encoder->get_global_headers();
#ifdef HAVE_AV1
		} else if (global_flags.av1_video_to_http) {
			video_extradata = av1_encoder->get_global_headers();
#endif
		}

		Mux *mux = new Mux(avctx, width, height, video_codec, video_extradata, stream_audio_encoder->get_codec_parameters().get(),
			get_color_space(global_flags.ycbcr_rec709_coefficients), COARSE_TIMEBASE,
			/*write_callback=*/nullptr, is_srt ? Mux::WRITE_BACKGROUND : Mux::WRITE_FOREGROUND, { is_srt ? &srt_mux_metrics : &http_mux_metrics });
		if (is_srt) {
			srt_mux.reset(mux);
			srt_mux_metrics.init({{ "destination", "srt" }});
			srt_metrics.init({{ "cardtype", "output" }});
			global_metrics.add("srt_num_connection_attempts", {{ "cardtype", "output" }}, &metric_srt_num_connection_attempts);
		} else {
			http_mux.reset(mux);
			http_mux_metrics.init({{ "destination", "http" }});
		}
	}
}

int VideoEncoder::write_packet2_thunk(void *opaque, uint8_t *buf, int buf_size, AVIODataMarkerType type, int64_t time)
{
	VideoEncoder *video_encoder = (VideoEncoder *)opaque;
	return video_encoder->write_packet2(buf, buf_size, type, time);
}

// Receives muxed HTTP stream data. Header blocks are accumulated and
// published through HTTPD::set_header(); all other blocks go to
// HTTPD::add_data(), flagged as a sync point where applicable.
// Always returns buf_size.
int VideoEncoder::write_packet2(uint8_t *buf, int buf_size, AVIODataMarkerType type, int64_t time)
{
	switch (type) {
	case AVIO_DATA_MARKER_SYNC_POINT:
	case AVIO_DATA_MARKER_BOUNDARY_POINT:
		seen_sync_markers = true;
		break;
	case AVIO_DATA_MARKER_UNKNOWN:
		if (!seen_sync_markers) {
			// No marker has been seen so far, so the muxer may simply not
			// be marking anything; we can't tell whether this is a keyframe,
			// and have to make the best of it by treating it as one.
			type = AVIO_DATA_MARKER_SYNC_POINT;
		}
		break;
	default:
		break;
	}

	char *data = reinterpret_cast<char *>(buf);
	if (type != AVIO_DATA_MARKER_HEADER) {
		const bool is_sync_point = (type == AVIO_DATA_MARKER_SYNC_POINT);
		httpd->add_data(HTTPD::StreamID{ HTTPD::MAIN_STREAM, 0 }, data, buf_size, is_sync_point, time, AVRational{ AV_TIME_BASE, 1 });
		return buf_size;
	}

	http_mux_header.append(data, buf_size);
	httpd->set_header(HTTPD::StreamID{ HTTPD::MAIN_STREAM, 0 }, http_mux_header);
	return buf_size;
}

int VideoEncoder::write_srt_packet_thunk(void *opaque, uint8_t *buf, int buf_size)
{
	VideoEncoder *video_encoder = (VideoEncoder *)opaque;
	return video_encoder->write_srt_packet(buf, buf_size);
}

// Formats the socket address in `ai` numerically as "host:port", with the
// host wrapped in square brackets for IPv6 ("[host]:port"). Returns
// "<unknown address>" if getnameinfo() fails.
static string print_addrinfo(const addrinfo *ai)
{
	char host[NI_MAXHOST];
	char service[NI_MAXSERV];

	// Purely numeric conversion, so no DNS lookups are involved.
	const int flags = NI_DGRAM | NI_NUMERICHOST | NI_NUMERICSERV;
	const int err = getnameinfo(ai->ai_addr, ai->ai_addrlen, host, sizeof(host), service, sizeof(service), flags);
	if (err != 0) {
		return "<unknown address>";  // Should basically never happen.
	}

	const bool is_ipv6 = (ai->ai_family == AF_INET6);
	string formatted;
	if (is_ipv6) {
		formatted += "[";
	}
	formatted += host;
	formatted += is_ipv6 ? "]:" : ":";
	formatted += service;
	return formatted;
}

int VideoEncoder::open_srt_socket()
{
	int sock = srt_create_socket();
	if (sock == -1) {
		fprintf(stderr, "srt_create_socket(): %s\n", srt_getlasterror_str());
		return -1;
	}

	SRT_TRANSTYPE live = SRTT_LIVE;
	if (srt_setsockopt(sock, 0, SRTO_TRANSTYPE, &live, sizeof(live)) < 0) {
		fprintf(stderr, "srt_setsockopt(SRTO_TRANSTYPE): %s\n", srt_getlasterror_str());
		srt_close(sock);
		return -1;
	}

	if (srt_setsockopt(sock, 0, SRTO_LATENCY, &global_flags.srt_output_latency_ms, sizeof(global_flags.srt_output_latency_ms)) < 0) {
		fprintf(stderr, "srt_setsockopt(SRTO_LATENCY): %s\n", srt_getlasterror_str());
		srt_close(sock);
		return -1;
	}

	if (!global_flags.srt_streamid.empty()) {
		if (srt_setsockopt(sock, 0, SRTO_STREAMID, global_flags.srt_streamid.data(), global_flags.srt_streamid.size()) < 0) {
			fprintf(stderr, "srt_setsockopt(SRTO_STREAMID): %s\n", srt_getlasterror_str());
			srt_close(sock);
			return -1;
		}
	}

	if (!global_flags.srt_passphrase.empty()) {
		if (srt_setsockopt(sock, 0, SRTO_PASSPHRASE, global_flags.srt_passphrase.data(), global_flags.srt_passphrase.size()) < 0) {
			fprintf(stderr, "srt_setsockopt(SRTO_PASSPHRASE): %s\n", srt_getlasterror_str());
			srt_close(sock);
			return -1;
		}
	}

	return sock;
}

// Resolves the configured SRT destination and tries each resulting address
// in turn, using a non-blocking connect so that should_quit can be honored
// while waiting.
//
// Returns a connected SRT socket, or -1 if
//  - name resolution fails,
//  - a socket cannot be created or srt_epoll_wait() fails outright,
//  - should_quit is set while connecting, or
//  - every candidate address has been tried without success.
// All failures are logged to stderr.
int VideoEncoder::connect_to_srt()
{
	// We need to specify SOCK_DGRAM as a hint, or we'll get all addresses
	// three times (for each of TCP, UDP, raw).
	addrinfo hints = {};
	hints.ai_flags = AI_ADDRCONFIG;
	hints.ai_socktype = SOCK_DGRAM;

	addrinfo *ai;
	int ret = getaddrinfo(global_flags.srt_destination_host.c_str(), global_flags.srt_destination_port.c_str(), &hints, &ai);
	if (ret != 0) {
		fprintf(stderr, "getaddrinfo(%s:%s): %s\n", global_flags.srt_destination_host.c_str(), global_flags.srt_destination_port.c_str(), gai_strerror(ret));
		return -1;
	}

	unique_ptr<addrinfo, decltype(freeaddrinfo) *> ai_cleanup(ai, &freeaddrinfo);

	for (const addrinfo *cur = ai; cur != nullptr; cur = cur->ai_next) {
		// Seemingly, srt_create_socket() isn't universal; once we try to connect,
		// it gets locked to either IPv4 or IPv6. So we need to create a new one
		// for every address we try.
		int sock = open_srt_socket();
		if (sock == -1) {
			// Die immediately.
			return sock;
		}
		++metric_srt_num_connection_attempts;

		// We do a non-blocking connect, so that we can check should_quit
		// every now and then. (In SRT, SRTO_RCVSYN is the option that
		// governs whether srt_connect() blocks.)
		int blocking = 0;
		if (srt_setsockopt(sock, 0, SRTO_RCVSYN, &blocking, sizeof(blocking)) < 0) {
			fprintf(stderr, "srt_setsockopt(SRTO_RCVSYN=0): %s\n", srt_getlasterror_str());
			srt_close(sock);
			continue;
		}
		if (srt_connect(sock, cur->ai_addr, cur->ai_addrlen) < 0) {
			fprintf(stderr, "srt_connect(%s): %s\n", print_addrinfo(cur).c_str(), srt_getlasterror_str());
			srt_close(sock);
			continue;
		}
		int eid = srt_epoll_create();
		if (eid < 0) {
			fprintf(stderr, "srt_epoll_create(): %s\n", srt_getlasterror_str());
			srt_close(sock);
			continue;
		}
		int modes = SRT_EPOLL_ERR | SRT_EPOLL_OUT;
		if (srt_epoll_add_usock(eid, sock, &modes) < 0) {
			fprintf(stderr, "srt_epoll_add_usock(): %s\n", srt_getlasterror_str());
			srt_close(sock);
			srt_epoll_release(eid);
			continue;
		}

		// Wait in 100 ms slices until the socket reports either writability
		// (connected) or an error, or until we are asked to quit.
		bool ok = false;
		while (!should_quit.load()) {
			SRTSOCKET errfds[1], writefds[1];
			int num_errfds = 1, num_writefds = 1;
			int poll_time_ms = 100;
			int wait_ret = srt_epoll_wait(eid, errfds, &num_errfds, writefds, &num_writefds, poll_time_ms, 0, 0, 0, 0);
			if (wait_ret < 0) {
				if (srt_getlasterror(nullptr) == SRT_ETIMEOUT) {
					continue;
				}
				fprintf(stderr, "srt_epoll_wait(): %s\n", srt_getlasterror_str());
				srt_close(sock);
				srt_epoll_release(eid);
				return -1;
			} else if (wait_ret > 0) {
				// The SRT epoll framework is pretty odd, but seemingly,
				// this is the way. Getting the same error code as srt_connect()
				// would normally return seems to be impossible, though.
				ok = (num_errfds == 0);
				break;
			}
		}
		srt_epoll_release(eid);
		if (should_quit.load()) {
			srt_close(sock);
			return -1;
		}
		if (ok) {
			fprintf(stderr, "Connected to destination SRT endpoint at %s.\n", print_addrinfo(cur).c_str());
			return sock;
		} else {
			fprintf(stderr, "srt_connect(%s): %s\n", print_addrinfo(cur).c_str(), srt_getlasterror_str());
			srt_close(sock);
		}
	}

	// Out of candidates, so give up.
	return -1;
}

// Sends buf_size bytes from buf over the SRT connection, in chunks of at
// most SRT_LIVE_DEF_PLSIZE bytes, (re)connecting as needed.
//
// If no connection can be established for longer than the configured SRT
// output latency, the SRT mux is drained and the remainder of this write is
// discarded once a connection is back. The loop also stops early when
// should_quit is set.
//
// Returns the number of bytes that were NOT sent (zero when the whole
// buffer went out).
int VideoEncoder::write_srt_packet(uint8_t *buf, int buf_size)
{
	// end_frame() requests a metrics refresh; do it at most once per request,
	// and only if we actually have a socket.
	if (want_srt_metric_update.exchange(false) && srt_sock != -1) {
		srt_metrics.update_srt_stats(srt_sock);
	}

	bool has_drained = false;  // Whether srt_mux->drain() has been called during this write.
	bool trying_reconnect = false;  // Whether first_connect_start has been set.
	steady_clock::time_point first_connect_start;  // When this write first found itself without a connection.

	while (buf_size > 0 && !should_quit.load()) {
		if (srt_sock == -1) {
			if (!trying_reconnect) {
				first_connect_start = steady_clock::now();
				trying_reconnect = true;
			}
			srt_sock = connect_to_srt();
			if (srt_sock == -1) {
				// Back off for 100 ms before the next attempt.
				usleep(100000);
				if (!has_drained && duration<double>(steady_clock::now() - first_connect_start).count() >= global_flags.srt_output_latency_ms * 1e-3) {
					// The entire concept for SRT is to have fixed, low latency.
					// If we've been out for more than a latency period, we shouldn't
					// try to send the entire backlog. (But we should be tolerant
					// of a quick disconnect and reconnect.) Maybe it would be better
					// to have a sliding window of how much we remove, but it quickly
					// starts getting esoteric, so just drop it all.
					fprintf(stderr, "WARNING: No SRT connection for more than %d ms, dropping data.\n",
						global_flags.srt_output_latency_ms);
					srt_mux->drain();
					has_drained = true;
				}
				continue;
			}
			srt_metrics.update_srt_stats(srt_sock);
		}
		if (has_drained) {
			// Now that we're reconnected, we can start accepting data again,
			// but discard the rest of this write (it is very old by now).
			srt_mux->undrain();
			break;
		}
		int to_send = min(buf_size, SRT_LIVE_DEF_PLSIZE);
		int ret = srt_send(srt_sock, (char *)buf, to_send);
		if (ret < 0)  {
			// The send failed: close the socket, mark the uptime metric as
			// NaN (0.0 / 0.0), and try to reconnect right away. If that fails,
			// srt_sock is -1 and the top of the loop takes over.
			fprintf(stderr, "srt_send(): %s\n", srt_getlasterror_str());
			srt_close(srt_sock);
			srt_metrics.metric_srt_uptime_seconds = 0.0 / 0.0;
			if (!trying_reconnect) {
				first_connect_start = steady_clock::now();
				trying_reconnect = true;
			}
			srt_sock = connect_to_srt();
			continue;
		}
		// Advance past the bytes that were sent.
		buf += ret;
		buf_size -= ret;
	}
	return buf_size;
}

