支持IP白名单访问限制

This commit is contained in:
chaos-zhu 2024-08-20 10:44:16 +08:00
parent 997761f2fc
commit 9f04c8adbb
13 changed files with 127 additions and 108 deletions

View File

@ -1,2 +1,5 @@
# 启动debug日志 0关闭 1开启
DEBUG=1
# 访问IP限制
allowedIPs=['127.0.0.1']

View File

@ -1,3 +1,4 @@
const ipFilter = require('./ipFilter') // IP过滤
const responseHandler = require('./response') // 统一返回格式, 错误捕获
const useAuth = require('./auth') // 鉴权
// const useCors = require('./cors') // 处理跨域[暂时禁止]
@ -8,8 +9,8 @@ const useStatic = require('./static') // 静态目录
const compress = require('./compress') // br/gzip压缩
const history = require('./history') // vue-router的history模式
// 注意注册顺序
module.exports = [
ipFilter,
compress,
history,
useStatic, // staic先注册不然会被jwt拦截

View File

@ -0,0 +1,16 @@
// 白名单IP
const fs = require('fs')
const path = require('path')
const { isAllowedIp } = require('../utils/tools')
const htmlPath = path.join(__dirname, '../template/ipForbidden.html')
const ipForbiddenHtml = fs.readFileSync(htmlPath, 'utf8')
const ipFilter = async (ctx, next) => {
// console.log('requestIP:', ctx.request.ip)
if (isAllowedIp(ctx.request.ip)) return await next()
ctx.status = 403
ctx.body = ipForbiddenHtml
}
module.exports = ipFilter

View File

@ -5,7 +5,6 @@ const { httpPort } = require('./config')
const middlewares = require('./middlewares')
const wsTerminal = require('./socket/terminal')
const wsSftp = require('./socket/sftp')
// const wsHostStatus = require('./socket/host-status')
const wsClientInfo = require('./socket/clients')
const wsOnekey = require('./socket/onekey')
const { throwError } = require('./utils/tools')
@ -25,7 +24,6 @@ function serverHandler(app, server) {
app.proxy = true // 用于nginx反代时获取真实客户端ip
wsTerminal(server) // 终端
wsSftp(server) // sftp
// wsHostStatus(server) // 终端侧边栏host信息(单个host)
wsOnekey(server) // 一键指令
wsClientInfo(server) // 客户端信息
app.context.throwError = throwError // 常用方法挂载全局ctx上

View File

@ -3,6 +3,7 @@ const { io: ClientIO } = require('socket.io-client')
const { readHostList } = require('../utils/storage')
const { clientPort } = require('../config')
const { verifyAuthSync } = require('../utils/verify-auth')
const { isAllowedIp } = require('../utils/tools')
let clientSockets = []
let clientsData = {}
@ -66,9 +67,14 @@ module.exports = (httpServer) => {
serverIo.on('connection', (socket) => {
// 前者兼容nginx反代, 后者兼容nodejs自身服务
let clientIp = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
let requestIP = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
if (!isAllowedIp(requestIP)) {
socket.emit('ip_forbidden', 'IP地址不在白名单中')
socket.disconnect()
return
}
socket.on('init_clients_data', async ({ token }) => {
const { code, msg } = await verifyAuthSync(token, clientIp)
const { code, msg } = await verifyAuthSync(token, requestIP)
if (code !== 1) {
socket.emit('token_verify_fail', msg || '鉴权失败')
socket.disconnect()

View File

@ -1,74 +0,0 @@
const { Server: ServerIO } = require('socket.io')
const { io: ClientIO } = require('socket.io-client')
const { clientPort } = require('../config')
const { verifyAuthSync } = require('../utils/verify-auth')
let hostSockets = {}
function getHostInfo(serverSocket, host) {
let hostSocket = ClientIO(`http://${ host }:${ clientPort }`, {
path: '/client/os-info',
forceNew: false,
timeout: 5000,
reconnectionDelay: 3000,
reconnectionAttempts: 3
})
// 将与客户端连接的socket实例保存起来web端断开时关闭与客户端的连接
hostSockets[serverSocket.id] = hostSocket
hostSocket
.on('connect', () => {
consola.success('host-status-socket连接成功:', host)
hostSocket.on('client_data', (data) => {
serverSocket.emit('host_data', data)
})
hostSocket.on('client_error', () => {
serverSocket.emit('host_data', null)
})
})
.on('connect_error', (error) => {
consola.error('host-status-socket连接[失败]:', host, error.message)
serverSocket.emit('host_data', null)
})
.on('disconnect', () => {
consola.info('host-status-socket连接[断开]:', host)
serverSocket.emit('host_data', null)
})
}
module.exports = (httpServer) => {
const serverIo = new ServerIO(httpServer, {
path: '/host-status',
cors: {
origin: '*' // 需配置跨域
}
})
serverIo.on('connection', (serverSocket) => {
// 前者兼容nginx反代, 后者兼容nodejs自身服务
let clientIp = serverSocket.handshake.headers['x-forwarded-for'] || serverSocket.handshake.address
serverSocket.on('init_host_data', async ({ token, host }) => {
// 校验登录态
const { code, msg } = await verifyAuthSync(token, clientIp)
if(code !== 1) {
serverSocket.emit('token_verify_fail', msg || '鉴权失败')
serverSocket.disconnect()
return
}
// 获取客户端数据
getHostInfo(serverSocket, host)
consola.info('host-status-socket连接socketId: ', serverSocket.id, 'host-status-socket已连接数: ', Object.keys(hostSockets).length)
// 关闭连接
serverSocket.on('disconnect', () => {
// 当web端与服务端断开连接时, 服务端与每个客户端的socket也应该断开连接
let socket = hostSockets[serverSocket.id]
socket.close && socket.close()
delete hostSockets[serverSocket.id]
consola.info('host-status-socket剩余连接数: ', Object.keys(hostSockets).length)
})
})
})
}

View File

@ -5,6 +5,7 @@ const { readSSHRecord, readHostList, writeOneKeyRecord } = require('../utils/sto
const { verifyAuthSync } = require('../utils/verify-auth')
const { shellThrottle } = require('../utils/tools')
const { AESDecryptSync } = require('../utils/encrypt')
const { isAllowedIp } = require('../utils/tools')
const execStatusEnum = {
connecting: '连接中',
@ -90,7 +91,12 @@ module.exports = (httpServer) => {
})
serverIo.on('connection', (socket) => {
// 前者兼容nginx反代, 后者兼容nodejs自身服务
let clientIp = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
let requestIP = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
if (!isAllowedIp(requestIP)) {
socket.emit('ip_forbidden', 'IP地址不在白名单中')
socket.disconnect()
return
}
consola.success('onekey-terminal websocket 已连接')
if (isExecuting) {
socket.emit('create_fail', '正在执行中, 请稍后再试')
@ -99,7 +105,7 @@ module.exports = (httpServer) => {
}
isExecuting = true
socket.on('create', async ({ hosts, token, command, timeout }) => {
const { code } = await verifyAuthSync(token, clientIp)
const { code } = await verifyAuthSync(token, requestIP)
if (code !== 1) {
socket.emit('token_verify_fail')
socket.disconnect()

View File

@ -7,6 +7,7 @@ const { sftpCacheDir } = require('../config')
const { verifyAuthSync } = require('../utils/verify-auth')
const { AESDecryptSync } = require('../utils/encrypt')
const { readSSHRecord, readHostList } = require('../utils/storage')
const { isAllowedIp } = require('../utils/tools')
// 读取切片
const pipeStream = (path, writeStream) => {
@ -210,12 +211,17 @@ module.exports = (httpServer) => {
})
serverIo.on('connection', (socket) => {
// 前者兼容nginx反代, 后者兼容nodejs自身服务
let clientIp = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
let requestIP = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
if (!isAllowedIp(requestIP)) {
socket.emit('ip_forbidden', 'IP地址不在白名单中')
socket.disconnect()
return
}
let sftpClient = new SFTPClient()
consola.success('terminal websocket 已连接')
socket.on('create', async ({ host: ip, token }) => {
const { code } = await verifyAuthSync(token, clientIp)
const { code } = await verifyAuthSync(token, requestIP)
if (code !== 1) {
socket.emit('token_verify_fail')
socket.disconnect()

View File

@ -4,6 +4,7 @@ const { verifyAuthSync } = require('../utils/verify-auth')
const { AESDecryptSync } = require('../utils/encrypt')
const { readSSHRecord, readHostList } = require('../utils/storage')
const { asyncSendNotice } = require('../utils/notify')
const { isAllowedIp } = require('../utils/tools')
function createInteractiveShell(socket, sshClient) {
return new Promise((resolve) => {
@ -113,11 +114,16 @@ module.exports = (httpServer) => {
})
serverIo.on('connection', (socket) => {
// 前者兼容nginx反代, 后者兼容nodejs自身服务
let clientIp = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
let requestIP = socket.handshake.headers['x-forwarded-for'] || socket.handshake.address
if (!isAllowedIp(requestIP)) {
socket.emit('ip_forbidden', 'IP地址不在白名单中')
socket.disconnect()
return
}
consola.success('terminal websocket 已连接')
let sshClient = null
socket.on('create', async ({ host: ip, token }) => {
const { code } = await verifyAuthSync(token, clientIp)
const { code } = await verifyAuthSync(token, requestIP)
if (code !== 1) {
socket.emit('token_verify_fail')
socket.disconnect()

View File

@ -0,0 +1,43 @@
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>403 禁止访问</title>
<link rel="icon" href="data:;base64,=">
<style>
body {
font-family: Arial, sans-serif;
background-color: #f4f4f4;
display: flex;
justify-content: center;
align-items: center;
height: 100vh;
margin: 0;
}
.container {
text-align: center;
padding: 20px;
border-radius: 8px;
}
h1 {
color: #d9534f;
}
p {
color: #333;
}
</style>
</head>
<body>
<div class="container">
<h1>403 禁止访问</h1>
<p>抱歉,您没有权限访问此页面。</p>
</div>
</body>
</html>

View File

@ -231,6 +231,13 @@ const isProd = () => {
return EXEC_ENV === 'production'
}
let allowedIPs = process.env.ALLOWED_IPS ? process.env.ALLOWED_IPS.split(',') : ''
if (allowedIPs) consola.warn('allowedIPs:', allowedIPs)
const isAllowedIp = (requestIP) => {
if (allowedIPs.length === 0) return true
return allowedIPs.some(item => item.includes(requestIP))
}
module.exports = {
getNetIPInfo,
throwError,
@ -240,5 +247,6 @@ module.exports = {
formatTimestamp,
resolvePath,
shellThrottle,
isProd
isProd,
isAllowedIp
}