-import type { IncomingMessage } from 'http';
+import { IncomingMessage, createServer } from 'http';
+import type internal from 'stream';
-import WebSocket, { RawData } from 'ws';
+import { StatusCodes } from 'http-status-codes';
+import WebSocket, { RawData, WebSocketServer } from 'ws';
import BaseError from '../../exception/BaseError';
-import type { ServerOptions } from '../../types/ConfigurationData';
+import type { UIServerConfiguration } from '../../types/ConfigurationData';
import type { ProtocolRequest, ProtocolResponse } from '../../types/UIProtocol';
import { WebSocketCloseEventStatusCode } from '../../types/WebSocket';
-import Configuration from '../../utils/Configuration';
import logger from '../../utils/Logger';
import Utils from '../../utils/Utils';
import { AbstractUIServer } from './AbstractUIServer';
const moduleName = 'UIWebSocketServer';
export default class UIWebSocketServer extends AbstractUIServer {
- public constructor(options?: ServerOptions) {
- super();
- this.server = new WebSocket.Server(options ?? Configuration.getUIServer().options);
+ private readonly webSocketServer: WebSocketServer;
+
+ public constructor(protected readonly uiServerConfiguration: UIServerConfiguration) {
+ super(uiServerConfiguration);
+ this.httpServer = createServer();
+ this.webSocketServer = new WebSocketServer({
+ handleProtocols: UIServiceUtils.handleProtocols,
+ noServer: true,
+ });
}
public start(): void {
- this.server.on('connection', (ws: WebSocket, request: IncomingMessage): void => {
+ this.webSocketServer.on('connection', (ws: WebSocket, req: IncomingMessage): void => {
const [protocol, version] = UIServiceUtils.getProtocolAndVersion(ws.protocol);
if (UIServiceUtils.isProtocolAndVersionSupported(protocol, version) === false) {
logger.error(
);
});
});
+ this.httpServer.on(
+ 'upgrade',
+ (req: IncomingMessage, socket: internal.Duplex, head: Buffer): void => {
+ this.authenticate(req, (err) => {
+ if (err) {
+ socket.write(`HTTP/1.1 ${StatusCodes.UNAUTHORIZED} Unauthorized\r\n\r\n`);
+ socket.destroy();
+ return;
+ }
+ this.webSocketServer.handleUpgrade(req, socket, head, (ws: WebSocket) => {
+ this.webSocketServer.emit('connection', ws, req);
+ });
+ });
+ }
+ );
+ if (this.httpServer.listening === false) {
+ this.httpServer.listen(this.uiServerConfiguration.options);
+ }
}
public stop(): void {
}
private broadcastToClients(message: string): void {
- for (const client of (this.server as WebSocket.Server).clients) {
+ for (const client of this.webSocketServer.clients) {
if (client?.readyState === WebSocket.OPEN) {
client.send(message);
}
}
}
+ private authenticate(req: IncomingMessage, next: (err: Error) => void): void {
+ if (this.isBasicAuthEnabled() === true) {
+ if (this.isValidBasicAuth(req) === false) {
+ next(new Error('Unauthorized'));
+ } else {
+ next(undefined);
+ }
+ } else {
+ next(undefined);
+ }
+ }
+
private validateRawDataRequest(rawData: RawData): ProtocolRequest {
// logger.debug(
// `${this.logPrefix(