mirror of
https://github.com/puppetlabs/vmpooler.git
synced 2026-01-26 01:58:41 -05:00
Add rate limiting and input validation security enhancements
This commit is contained in:
parent
1a6b08ab81
commit
d0020becb3
7 changed files with 545 additions and 17 deletions
|
|
@ -1,10 +1,13 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require 'vmpooler/api/input_validator'
|
||||
|
||||
module Vmpooler
|
||||
|
||||
class API
|
||||
|
||||
module Helpers
|
||||
include InputValidator
|
||||
|
||||
def tracer
|
||||
@tracer ||= OpenTelemetry.tracer_provider.tracer('api', Vmpooler::VERSION)
|
||||
|
|
|
|||
159
lib/vmpooler/api/input_validator.rb
Normal file
159
lib/vmpooler/api/input_validator.rb
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module Vmpooler
|
||||
class API
|
||||
# Input validation helpers to enhance security
|
||||
module InputValidator
|
||||
# Maximum lengths to prevent abuse
|
||||
MAX_HOSTNAME_LENGTH = 253
|
||||
MAX_TAG_KEY_LENGTH = 50
|
||||
MAX_TAG_VALUE_LENGTH = 255
|
||||
MAX_REASON_LENGTH = 500
|
||||
MAX_POOL_NAME_LENGTH = 100
|
||||
MAX_TOKEN_LENGTH = 64
|
||||
|
||||
# Valid patterns
|
||||
HOSTNAME_PATTERN = /\A[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)* \z/ix.freeze
|
||||
POOL_NAME_PATTERN = /\A[a-zA-Z0-9_-]+\z/.freeze
|
||||
TAG_KEY_PATTERN = /\A[a-zA-Z0-9_\-.]+\z/.freeze
|
||||
TOKEN_PATTERN = /\A[a-zA-Z0-9\-_]+\z/.freeze
|
||||
INTEGER_PATTERN = /\A\d+\z/.freeze
|
||||
|
||||
class ValidationError < StandardError; end
|
||||
|
||||
# Validate hostname format and length
|
||||
def validate_hostname(hostname)
|
||||
return error_response('Hostname is required') if hostname.nil? || hostname.empty?
|
||||
return error_response('Hostname too long') if hostname.length > MAX_HOSTNAME_LENGTH
|
||||
return error_response('Invalid hostname format') unless hostname.match?(HOSTNAME_PATTERN)
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
# Validate pool/template name
|
||||
def validate_pool_name(pool_name)
|
||||
return error_response('Pool name is required') if pool_name.nil? || pool_name.empty?
|
||||
return error_response('Pool name too long') if pool_name.length > MAX_POOL_NAME_LENGTH
|
||||
return error_response('Invalid pool name format') unless pool_name.match?(POOL_NAME_PATTERN)
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
# Validate tag key and value
|
||||
def validate_tag(key, value)
|
||||
return error_response('Tag key is required') if key.nil? || key.empty?
|
||||
return error_response('Tag key too long') if key.length > MAX_TAG_KEY_LENGTH
|
||||
return error_response('Invalid tag key format') unless key.match?(TAG_KEY_PATTERN)
|
||||
|
||||
if value
|
||||
return error_response('Tag value too long') if value.length > MAX_TAG_VALUE_LENGTH
|
||||
|
||||
# Sanitize value to prevent injection attacks
|
||||
sanitized_value = value.gsub(/[^\w\s\-.@:\/]/, '')
|
||||
return error_response('Tag value contains invalid characters') if sanitized_value != value
|
||||
end
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
# Validate token format
|
||||
def validate_token_format(token)
|
||||
return error_response('Token is required') if token.nil? || token.empty?
|
||||
return error_response('Token too long') if token.length > MAX_TOKEN_LENGTH
|
||||
return error_response('Invalid token format') unless token.match?(TOKEN_PATTERN)
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
# Validate integer parameter
|
||||
def validate_integer(value, name = 'value', min: nil, max: nil)
|
||||
return error_response("#{name} is required") if value.nil?
|
||||
|
||||
value_str = value.to_s
|
||||
return error_response("#{name} must be a valid integer") unless value_str.match?(INTEGER_PATTERN)
|
||||
|
||||
int_value = value.to_i
|
||||
return error_response("#{name} must be at least #{min}") if min && int_value < min
|
||||
return error_response("#{name} must be at most #{max}") if max && int_value > max
|
||||
|
||||
int_value
|
||||
end
|
||||
|
||||
# Validate VM request count
|
||||
def validate_vm_count(count)
|
||||
validated = validate_integer(count, 'VM count', min: 1, max: 100)
|
||||
return validated if validated.is_a?(Hash) # error response
|
||||
|
||||
validated
|
||||
end
|
||||
|
||||
# Validate disk size
|
||||
def validate_disk_size(size)
|
||||
validated = validate_integer(size, 'Disk size', min: 1, max: 2048)
|
||||
return validated if validated.is_a?(Hash) # error response
|
||||
|
||||
validated
|
||||
end
|
||||
|
||||
# Validate lifetime (TTL) in hours
|
||||
def validate_lifetime(lifetime)
|
||||
validated = validate_integer(lifetime, 'Lifetime', min: 1, max: 168) # max 1 week
|
||||
return validated if validated.is_a?(Hash) # error response
|
||||
|
||||
validated
|
||||
end
|
||||
|
||||
# Validate reason text
|
||||
def validate_reason(reason)
|
||||
return true if reason.nil? || reason.empty?
|
||||
return error_response('Reason too long') if reason.length > MAX_REASON_LENGTH
|
||||
|
||||
# Sanitize to prevent XSS/injection
|
||||
sanitized = reason.gsub(/[<>"']/, '')
|
||||
return error_response('Reason contains invalid characters') if sanitized != reason
|
||||
|
||||
true
|
||||
end
|
||||
|
||||
# Sanitize JSON body to prevent injection
|
||||
def sanitize_json_body(body)
|
||||
return {} if body.nil? || body.empty?
|
||||
|
||||
begin
|
||||
parsed = JSON.parse(body)
|
||||
return error_response('Request body must be a JSON object') unless parsed.is_a?(Hash)
|
||||
|
||||
# Limit depth and size to prevent DoS
|
||||
return error_response('Request body too complex') if json_depth(parsed) > 5
|
||||
return error_response('Request body too large') if body.length > 10_240 # 10KB max
|
||||
|
||||
parsed
|
||||
rescue JSON::ParserError => e
|
||||
error_response("Invalid JSON: #{e.message}")
|
||||
end
|
||||
end
|
||||
|
||||
# Check if validation result is an error
|
||||
def validation_error?(result)
|
||||
result.is_a?(Hash) && result['ok'] == false
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def error_response(message)
|
||||
{ 'ok' => false, 'error' => message }
|
||||
end
|
||||
|
||||
def json_depth(obj, depth = 0)
|
||||
return depth unless obj.is_a?(Hash) || obj.is_a?(Array)
|
||||
return depth + 1 if obj.empty?
|
||||
|
||||
if obj.is_a?(Hash)
|
||||
depth + 1 + obj.values.map { |v| json_depth(v, 0) }.max
|
||||
else
|
||||
depth + 1 + obj.map { |v| json_depth(v, 0) }.max
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
116
lib/vmpooler/api/rate_limiter.rb
Normal file
116
lib/vmpooler/api/rate_limiter.rb
Normal file
|
|
@ -0,0 +1,116 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module Vmpooler
|
||||
class API
|
||||
# Rate limiter middleware to protect against abuse
|
||||
# Uses Redis to track request counts per IP and token
|
||||
class RateLimiter
|
||||
DEFAULT_LIMITS = {
|
||||
global_per_ip: { limit: 100, period: 60 }, # 100 requests per minute per IP
|
||||
authenticated: { limit: 500, period: 60 }, # 500 requests per minute with token
|
||||
vm_creation: { limit: 20, period: 60 }, # 20 VM creations per minute
|
||||
vm_deletion: { limit: 50, period: 60 } # 50 VM deletions per minute
|
||||
}.freeze
|
||||
|
||||
def initialize(app, redis, config = {})
|
||||
@app = app
|
||||
@redis = redis
|
||||
@config = DEFAULT_LIMITS.merge(config[:rate_limits] || {})
|
||||
@enabled = config.fetch(:rate_limiting_enabled, true)
|
||||
end
|
||||
|
||||
def call(env)
|
||||
return @app.call(env) unless @enabled
|
||||
|
||||
request = Rack::Request.new(env)
|
||||
client_id = identify_client(request)
|
||||
endpoint_type = classify_endpoint(request)
|
||||
|
||||
# Check rate limits
|
||||
return rate_limit_response(client_id, endpoint_type) if rate_limit_exceeded?(client_id, endpoint_type, request)
|
||||
|
||||
# Track the request
|
||||
increment_request_count(client_id, endpoint_type)
|
||||
|
||||
@app.call(env)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def identify_client(request)
|
||||
# Prioritize token-based identification for authenticated requests
|
||||
token = request.env['HTTP_X_AUTH_TOKEN']
|
||||
return "token:#{token}" if token && !token.empty?
|
||||
|
||||
# Fall back to IP address
|
||||
ip = request.ip || request.env['REMOTE_ADDR'] || 'unknown'
|
||||
"ip:#{ip}"
|
||||
end
|
||||
|
||||
def classify_endpoint(request)
|
||||
path = request.path
|
||||
method = request.request_method
|
||||
|
||||
return :vm_creation if method == 'POST' && path.include?('/vm')
|
||||
return :vm_deletion if method == 'DELETE' && path.include?('/vm')
|
||||
return :authenticated if request.env['HTTP_X_AUTH_TOKEN']
|
||||
|
||||
:global_per_ip
|
||||
end
|
||||
|
||||
def rate_limit_exceeded?(client_id, endpoint_type, _request)
|
||||
limit_config = @config[endpoint_type] || @config[:global_per_ip]
|
||||
key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}"
|
||||
|
||||
current_count = @redis.get(key).to_i
|
||||
current_count >= limit_config[:limit]
|
||||
rescue StandardError => e
|
||||
# If Redis fails, allow the request through (fail open)
|
||||
warn "Rate limiter Redis error: #{e.message}"
|
||||
false
|
||||
end
|
||||
|
||||
def increment_request_count(client_id, endpoint_type)
|
||||
limit_config = @config[endpoint_type] || @config[:global_per_ip]
|
||||
key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}"
|
||||
|
||||
@redis.pipelined do |pipeline|
|
||||
pipeline.incr(key)
|
||||
pipeline.expire(key, limit_config[:period])
|
||||
end
|
||||
rescue StandardError => e
|
||||
# Log error but don't fail the request
|
||||
warn "Rate limiter increment error: #{e.message}"
|
||||
end
|
||||
|
||||
def rate_limit_response(client_id, endpoint_type)
|
||||
limit_config = @config[endpoint_type] || @config[:global_per_ip]
|
||||
key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}"
|
||||
|
||||
begin
|
||||
ttl = @redis.ttl(key)
|
||||
rescue StandardError
|
||||
ttl = limit_config[:period]
|
||||
end
|
||||
|
||||
headers = {
|
||||
'Content-Type' => 'application/json',
|
||||
'X-RateLimit-Limit' => limit_config[:limit].to_s,
|
||||
'X-RateLimit-Remaining' => '0',
|
||||
'X-RateLimit-Reset' => (Time.now.to_i + ttl).to_s,
|
||||
'Retry-After' => ttl.to_s
|
||||
}
|
||||
|
||||
body = JSON.pretty_generate({
|
||||
'ok' => false,
|
||||
'error' => 'Rate limit exceeded',
|
||||
'limit' => limit_config[:limit],
|
||||
'period' => limit_config[:period],
|
||||
'retry_after' => ttl
|
||||
})
|
||||
|
||||
[429, headers, [body]]
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
@ -1085,9 +1085,29 @@ module Vmpooler
|
|||
result = { 'ok' => false }
|
||||
metrics.increment('http_requests_vm_total.post.vm.checkout')
|
||||
|
||||
payload = JSON.parse(request.body.read)
|
||||
# Validate and sanitize JSON body
|
||||
payload = sanitize_json_body(request.body.read)
|
||||
if validation_error?(payload)
|
||||
status 400
|
||||
return JSON.pretty_generate(payload)
|
||||
end
|
||||
|
||||
if payload
|
||||
# Validate each template and count
|
||||
payload.each do |template, count|
|
||||
validation = validate_pool_name(template)
|
||||
if validation_error?(validation)
|
||||
status 400
|
||||
return JSON.pretty_generate(validation)
|
||||
end
|
||||
|
||||
validated_count = validate_vm_count(count)
|
||||
if validation_error?(validated_count)
|
||||
status 400
|
||||
return JSON.pretty_generate(validated_count)
|
||||
end
|
||||
end
|
||||
|
||||
if payload && !payload.empty?
|
||||
invalid = invalid_templates(payload)
|
||||
if invalid.empty?
|
||||
result = atomically_allocate_vms(payload)
|
||||
|
|
@ -1206,6 +1226,7 @@ module Vmpooler
|
|||
result = { 'ok' => false }
|
||||
metrics.increment('http_requests_vm_total.get.vm.template')
|
||||
|
||||
# Template can contain multiple pools separated by +, so validate after parsing
|
||||
payload = extract_templates_from_query_params(params[:template])
|
||||
|
||||
if payload
|
||||
|
|
@ -1235,6 +1256,13 @@ module Vmpooler
|
|||
status 404
|
||||
result['ok'] = false
|
||||
|
||||
# Validate hostname
|
||||
validation = validate_hostname(params[:hostname])
|
||||
if validation_error?(validation)
|
||||
status 400
|
||||
return JSON.pretty_generate(validation)
|
||||
end
|
||||
|
||||
params[:hostname] = hostname_shorten(params[:hostname])
|
||||
|
||||
rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}")
|
||||
|
|
@ -1373,6 +1401,13 @@ module Vmpooler
|
|||
status 404
|
||||
result['ok'] = false
|
||||
|
||||
# Validate hostname
|
||||
validation = validate_hostname(params[:hostname])
|
||||
if validation_error?(validation)
|
||||
status 400
|
||||
return JSON.pretty_generate(validation)
|
||||
end
|
||||
|
||||
params[:hostname] = hostname_shorten(params[:hostname])
|
||||
|
||||
rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}")
|
||||
|
|
@ -1403,16 +1438,21 @@ module Vmpooler
|
|||
|
||||
failure = []
|
||||
|
||||
# Validate hostname
|
||||
validation = validate_hostname(params[:hostname])
|
||||
if validation_error?(validation)
|
||||
status 400
|
||||
return JSON.pretty_generate(validation)
|
||||
end
|
||||
|
||||
params[:hostname] = hostname_shorten(params[:hostname])
|
||||
|
||||
if backend.exists?("vmpooler__vm__#{params[:hostname]}")
|
||||
begin
|
||||
jdata = JSON.parse(request.body.read)
|
||||
rescue StandardError => e
|
||||
span = OpenTelemetry::Trace.current_span
|
||||
span.record_exception(e)
|
||||
span.status = OpenTelemetry::Trace::Status.error(e.to_s)
|
||||
halt 400, JSON.pretty_generate(result)
|
||||
# Validate and sanitize JSON body
|
||||
jdata = sanitize_json_body(request.body.read)
|
||||
if validation_error?(jdata)
|
||||
status 400
|
||||
return JSON.pretty_generate(jdata)
|
||||
end
|
||||
|
||||
# Validate data payload
|
||||
|
|
@ -1421,6 +1461,13 @@ module Vmpooler
|
|||
when 'lifetime'
|
||||
need_token! if Vmpooler::API.settings.config[:auth]
|
||||
|
||||
# Validate lifetime is a positive integer
|
||||
lifetime_int = arg.to_i
|
||||
if lifetime_int <= 0
|
||||
failure.push("Lifetime must be a positive integer (got #{arg})")
|
||||
next
|
||||
end
|
||||
|
||||
# in hours, defaults to one week
|
||||
max_lifetime_upper_limit = config['max_lifetime_upper_limit']
|
||||
if max_lifetime_upper_limit
|
||||
|
|
@ -1430,13 +1477,17 @@ module Vmpooler
|
|||
end
|
||||
end
|
||||
|
||||
# validate lifetime is within boundaries
|
||||
unless arg.to_i > 0
|
||||
failure.push("You provided a lifetime (#{arg}) but you must provide a positive number.")
|
||||
end
|
||||
|
||||
when 'tags'
|
||||
failure.push("You provided tags (#{arg}) as something other than a hash.") unless arg.is_a?(Hash)
|
||||
|
||||
# Validate each tag key and value
|
||||
arg.each do |key, value|
|
||||
tag_validation = validate_tag(key, value)
|
||||
if validation_error?(tag_validation)
|
||||
failure.push(tag_validation['error'])
|
||||
end
|
||||
end
|
||||
|
||||
failure.push("You provided unsuppored tags (#{arg}).") if config['allowed_tags'] && !(arg.keys - config['allowed_tags']).empty?
|
||||
else
|
||||
failure.push("Unknown argument #{arg}.")
|
||||
|
|
@ -1478,9 +1529,23 @@ module Vmpooler
|
|||
status 404
|
||||
result = { 'ok' => false }
|
||||
|
||||
# Validate hostname
|
||||
validation = validate_hostname(params[:hostname])
|
||||
if validation_error?(validation)
|
||||
status 400
|
||||
return JSON.pretty_generate(validation)
|
||||
end
|
||||
|
||||
# Validate disk size
|
||||
validated_size = validate_disk_size(params[:size])
|
||||
if validation_error?(validated_size)
|
||||
status 400
|
||||
return JSON.pretty_generate(validated_size)
|
||||
end
|
||||
|
||||
params[:hostname] = hostname_shorten(params[:hostname])
|
||||
|
||||
if ((params[:size].to_i > 0 )and (backend.exists?("vmpooler__vm__#{params[:hostname]}")))
|
||||
if backend.exists?("vmpooler__vm__#{params[:hostname]}")
|
||||
result[params[:hostname]] = {}
|
||||
result[params[:hostname]]['disk'] = "+#{params[:size]}gb"
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue