/*
 * JBoss, Home of Professional Open Source.
 * Copyright 2014 Red Hat, Inc., and individual contributors
 * as indicated by the @author tags.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 *  Unless required by applicable law or agreed to in writing, software
 *  distributed under the License is distributed on an "AS IS" BASIS,
 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 *  See the License for the specific language governing permissions and
 *  limitations under the License.
 */

package io.undertow.websockets.core.protocol;

import io.undertow.util.Headers;
import io.undertow.websockets.WebSocketExtension;
import io.undertow.websockets.core.WebSocketChannel;
import io.undertow.websockets.core.WebSocketVersion;
import io.undertow.websockets.extensions.CompositeExtensionFunction;
import io.undertow.websockets.extensions.ExtensionFunction;
import io.undertow.websockets.extensions.ExtensionHandshake;
import io.undertow.websockets.spi.WebSocketHttpExchange;
import org.xnio.IoFuture;
import io.undertow.connector.ByteBufferPool;
import org.xnio.StreamConnection;

import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;

/**
 * Abstract base class for doing a WebSocket Handshake.
 *
 * @author Mike Brock
 */
public abstract class Handshake {
    private final WebSocketVersion version;
    private final String hashAlgorithm;
    private final String magicNumber;
    protected final Set<String> subprotocols;
    private static final byte[] EMPTY = new byte[0];
    private static final Pattern PATTERN = Pattern.compile("\\s*,\\s*");
    protected Set<ExtensionHandshake> availableExtensions = new HashSet<>();
    protected boolean allowExtensions;

    protected Handshake(WebSocketVersion version, String hashAlgorithm, String magicNumber, final Set<String> subprotocols) {
        this.version = version;
        this.hashAlgorithm = hashAlgorithm;
        this.magicNumber = magicNumber;
        this.subprotocols = subprotocols;
    }

    /**
     * Return the version for which the {@link Handshake} can be used.
     */
    public WebSocketVersion getVersion() {
        return version;
    }

    /**
     * Return the algorithm that is used to hash during the handshake
     */
    public String getHashAlgorithm() {
        return hashAlgorithm;
    }

    /**
     * Return the magic number which will be mixed in
     */
    public String getMagicNumber() {
        return magicNumber;
    }

    /**
     * Return the full url of the websocket location of the given {@link WebSocketHttpExchange}
     */
    protected static String getWebSocketLocation(WebSocketHttpExchange exchange) {
        String scheme;
        if ("https".equals(exchange.getRequestScheme())) {
            scheme = "wss";
        } else {
            scheme = "ws";
        }
        return scheme + "://" + exchange.getRequestHeader(Headers.HOST_STRING) + exchange.getRequestURI();
    }

    /**
     * Issue the WebSocket upgrade
     *
     * @param exchange The {@link WebSocketHttpExchange} for which the handshake and upgrade should occur.
     */
    public final void handshake(final WebSocketHttpExchange exchange) {
        exchange.putAttachment(WebSocketVersion.ATTACHMENT_KEY, version);
        handshakeInternal(exchange);
    }

    protected abstract void handshakeInternal(final WebSocketHttpExchange exchange);

    /**
     * Return {@code true} if this implementation can be used to issue a handshake.
     */
    public abstract boolean matches(WebSocketHttpExchange exchange);

    /**
     * Create the {@link WebSocketChannel} from the {@link WebSocketHttpExchange}
     */
    public abstract WebSocketChannel createChannel(WebSocketHttpExchange exchange, final StreamConnection channel, final ByteBufferPool pool);

    /**
     * convenience method to perform the upgrade
     */
    protected final void performUpgrade(final WebSocketHttpExchange exchange, final byte[] data) {
        exchange.setResponseHeader(Headers.CONTENT_LENGTH_STRING, String.valueOf(data.length));
        exchange.setResponseHeader(Headers.UPGRADE_STRING, "WebSocket");
        exchange.setResponseHeader(Headers.CONNECTION_STRING, "Upgrade");
        upgradeChannel(exchange, data);
    }

    protected void upgradeChannel(final WebSocketHttpExchange exchange, final byte[] data) {
        if (data.length > 0) {
            writePayload(exchange, ByteBuffer.wrap(data));
        } else {
            exchange.endExchange();
        }
    }

    private static void writePayload(final WebSocketHttpExchange exchange, final ByteBuffer payload) {
        exchange.sendData(payload).addNotifier(new IoFuture.Notifier<Void, Object>() {
            @Override
            public void notify(final IoFuture<? extends Void> ioFuture, final Object attachment) {
                if (ioFuture.getStatus() == IoFuture.Status.DONE) {
                    exchange.endExchange();
                } else {
                    exchange.close();
                }
            }
        }, null);
    }

    /**
     * Perform the upgrade using no payload
     */
    protected final void performUpgrade(final WebSocketHttpExchange exchange) {
        performUpgrade(exchange, EMPTY);
    }

    /**
     * Selects the first matching supported sub protocol and add it the the headers of the exchange.
     *
     */
    protected final void selectSubprotocol(final WebSocketHttpExchange exchange) {
        String requestedSubprotocols = exchange.getRequestHeader(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING);
        if (requestedSubprotocols == null) {
            return;
        }

        String[] requestedSubprotocolArray = PATTERN.split(requestedSubprotocols);
        String subProtocol = supportedSubprotols(requestedSubprotocolArray);
        if (subProtocol != null && !subProtocol.isEmpty()) {
            exchange.setResponseHeader(Headers.SEC_WEB_SOCKET_PROTOCOL_STRING, subProtocol);
        }

    }


    protected final void selectExtensions(final WebSocketHttpExchange exchange) {
        List<WebSocketExtension> requestedExtensions = WebSocketExtension.parse(exchange.getRequestHeader(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING));
        List<WebSocketExtension> extensions = selectedExtension(requestedExtensions);
        if (extensions != null && !extensions.isEmpty()) {
            exchange.setResponseHeader(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING, WebSocketExtension.toExtensionHeader(extensions));
        }

    }

    protected String supportedSubprotols(String[] requestedSubprotocolArray) {
        for (String p : requestedSubprotocolArray) {
            String requestedSubprotocol = p.trim();

            for (String supportedSubprotocol : subprotocols) {
                if (requestedSubprotocol.equals(supportedSubprotocol)) {
                    return supportedSubprotocol;
                }
            }
        }
        return null;
    }

    protected List<WebSocketExtension> selectedExtension(List<WebSocketExtension> extensionList) {
        List<WebSocketExtension> selected = new ArrayList<>();
        List<ExtensionHandshake> configured = new ArrayList<>();
        for (WebSocketExtension ext : extensionList) {
            for (ExtensionHandshake extHandshake : availableExtensions) {
                WebSocketExtension negotiated = extHandshake.accept(ext);
                if (ext != null && !extHandshake.isIncompatible(configured)) {
                    selected.add(negotiated);
                    configured.add(extHandshake);
                }
            }
        }
        return selected;
    }

    /**
     * Add a new WebSocket Extension handshake to the list of available extensions.
     *
     * @param extension a new {@code ExtensionHandshake}
     */
    public final void addExtension(ExtensionHandshake extension) {
        availableExtensions.add(extension);
        allowExtensions = true;
    }

    /**
     * Create the {@code ExtensionFunction} list associated with the negotiated extensions defined in the exchange's response.
     *
     * @param exchange the exchange used to retrieve negotiated extensions
     * @return         a list of {@code ExtensionFunction} with the implementation of the extensions
     */
    protected final ExtensionFunction initExtensions(final WebSocketHttpExchange exchange) {
        String extHeader = exchange.getResponseHeaders().get(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING) != null ?
                exchange.getResponseHeaders().get(Headers.SEC_WEB_SOCKET_EXTENSIONS_STRING).get(0) : null;

        List<ExtensionFunction> negotiated = new ArrayList<>();
        if (extHeader != null) {
            List<WebSocketExtension> extensions = WebSocketExtension.parse(extHeader);
            if (extensions != null && !extensions.isEmpty()) {
                for (WebSocketExtension ext : extensions) {
                    for (ExtensionHandshake extHandshake : availableExtensions) {
                        if (extHandshake.getName().equals(ext.getName())) {
                            negotiated.add(extHandshake.create());
                        }
                    }
                }
            }
        }
        return CompositeExtensionFunction.compose(negotiated);
    }
}
