Merge pull request #690 from puppetlabs/P4DEVOPS-9434
Some checks failed
Security / Mend Scanning (push) Has been cancelled

Add rate limiting and input validation security enhancements
This commit is contained in:
Mahima Singh 2025-12-26 15:44:39 +05:30 committed by GitHub
commit 76eb62577b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 542 additions and 15 deletions

View file

@ -1,10 +1,13 @@
# frozen_string_literal: true # frozen_string_literal: true
require 'vmpooler/api/input_validator'
module Vmpooler module Vmpooler
class API class API
module Helpers module Helpers
include InputValidator
def tracer def tracer
@tracer ||= OpenTelemetry.tracer_provider.tracer('api', Vmpooler::VERSION) @tracer ||= OpenTelemetry.tracer_provider.tracer('api', Vmpooler::VERSION)

View file

@ -0,0 +1,159 @@
# frozen_string_literal: true
module Vmpooler
class API
# Input validation helpers to enhance security
module InputValidator
# Maximum lengths to prevent abuse
MAX_HOSTNAME_LENGTH = 253
MAX_TAG_KEY_LENGTH = 50
MAX_TAG_VALUE_LENGTH = 255
MAX_REASON_LENGTH = 500
MAX_POOL_NAME_LENGTH = 100
MAX_TOKEN_LENGTH = 64
# Valid patterns
HOSTNAME_PATTERN = /\A[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)* \z/ix.freeze
POOL_NAME_PATTERN = /\A[a-zA-Z0-9_-]+\z/.freeze
TAG_KEY_PATTERN = /\A[a-zA-Z0-9_\-.]+\z/.freeze
TOKEN_PATTERN = /\A[a-zA-Z0-9\-_]+\z/.freeze
INTEGER_PATTERN = /\A\d+\z/.freeze
class ValidationError < StandardError; end
# Validate hostname format and length
def validate_hostname(hostname)
return error_response('Hostname is required') if hostname.nil? || hostname.empty?
return error_response('Hostname too long') if hostname.length > MAX_HOSTNAME_LENGTH
return error_response('Invalid hostname format') unless hostname.match?(HOSTNAME_PATTERN)
true
end
# Validate pool/template name
def validate_pool_name(pool_name)
return error_response('Pool name is required') if pool_name.nil? || pool_name.empty?
return error_response('Pool name too long') if pool_name.length > MAX_POOL_NAME_LENGTH
return error_response('Invalid pool name format') unless pool_name.match?(POOL_NAME_PATTERN)
true
end
# Validate tag key and value
def validate_tag(key, value)
return error_response('Tag key is required') if key.nil? || key.empty?
return error_response('Tag key too long') if key.length > MAX_TAG_KEY_LENGTH
return error_response('Invalid tag key format') unless key.match?(TAG_KEY_PATTERN)
if value
return error_response('Tag value too long') if value.length > MAX_TAG_VALUE_LENGTH
# Sanitize value to prevent injection attacks
sanitized_value = value.gsub(/[^\w\s\-.@:\/]/, '')
return error_response('Tag value contains invalid characters') if sanitized_value != value
end
true
end
# Validate token format
def validate_token_format(token)
return error_response('Token is required') if token.nil? || token.empty?
return error_response('Token too long') if token.length > MAX_TOKEN_LENGTH
return error_response('Invalid token format') unless token.match?(TOKEN_PATTERN)
true
end
# Validate integer parameter
def validate_integer(value, name = 'value', min: nil, max: nil)
return error_response("#{name} is required") if value.nil?
value_str = value.to_s
return error_response("#{name} must be a valid integer") unless value_str.match?(INTEGER_PATTERN)
int_value = value.to_i
return error_response("#{name} must be at least #{min}") if min && int_value < min
return error_response("#{name} must be at most #{max}") if max && int_value > max
int_value
end
# Validate VM request count
def validate_vm_count(count)
validated = validate_integer(count, 'VM count', min: 1, max: 100)
return validated if validated.is_a?(Hash) # error response
validated
end
# Validate disk size
def validate_disk_size(size)
validated = validate_integer(size, 'Disk size', min: 1, max: 2048)
return validated if validated.is_a?(Hash) # error response
validated
end
# Validate lifetime (TTL) in hours
def validate_lifetime(lifetime)
validated = validate_integer(lifetime, 'Lifetime', min: 1, max: 168) # max 1 week
return validated if validated.is_a?(Hash) # error response
validated
end
# Validate reason text
def validate_reason(reason)
return true if reason.nil? || reason.empty?
return error_response('Reason too long') if reason.length > MAX_REASON_LENGTH
# Sanitize to prevent XSS/injection
sanitized = reason.gsub(/[<>"']/, '')
return error_response('Reason contains invalid characters') if sanitized != reason
true
end
# Sanitize JSON body to prevent injection
def sanitize_json_body(body)
return {} if body.nil? || body.empty?
begin
parsed = JSON.parse(body)
return error_response('Request body must be a JSON object') unless parsed.is_a?(Hash)
# Limit depth and size to prevent DoS
return error_response('Request body too complex') if json_depth(parsed) > 5
return error_response('Request body too large') if body.length > 10_240 # 10KB max
parsed
rescue JSON::ParserError => e
error_response("Invalid JSON: #{e.message}")
end
end
# Check if validation result is an error
def validation_error?(result)
result.is_a?(Hash) && result['ok'] == false
end
private
def error_response(message)
{ 'ok' => false, 'error' => message }
end
def json_depth(obj, depth = 0)
return depth unless obj.is_a?(Hash) || obj.is_a?(Array)
return depth + 1 if obj.empty?
if obj.is_a?(Hash)
depth + 1 + obj.values.map { |v| json_depth(v, 0) }.max
else
depth + 1 + obj.map { |v| json_depth(v, 0) }.max
end
end
end
end
end

View file

@ -0,0 +1,116 @@
# frozen_string_literal: true
module Vmpooler
class API
# Rate limiter middleware to protect against abuse
# Uses Redis to track request counts per IP and token
class RateLimiter
DEFAULT_LIMITS = {
global_per_ip: { limit: 100, period: 60 }, # 100 requests per minute per IP
authenticated: { limit: 500, period: 60 }, # 500 requests per minute with token
vm_creation: { limit: 20, period: 60 }, # 20 VM creations per minute
vm_deletion: { limit: 50, period: 60 } # 50 VM deletions per minute
}.freeze
def initialize(app, redis, config = {})
@app = app
@redis = redis
@config = DEFAULT_LIMITS.merge(config[:rate_limits] || {})
@enabled = config.fetch(:rate_limiting_enabled, true)
end
def call(env)
return @app.call(env) unless @enabled
request = Rack::Request.new(env)
client_id = identify_client(request)
endpoint_type = classify_endpoint(request)
# Check rate limits
return rate_limit_response(client_id, endpoint_type) if rate_limit_exceeded?(client_id, endpoint_type, request)
# Track the request
increment_request_count(client_id, endpoint_type)
@app.call(env)
end
private
def identify_client(request)
# Prioritize token-based identification for authenticated requests
token = request.env['HTTP_X_AUTH_TOKEN']
return "token:#{token}" if token && !token.empty?
# Fall back to IP address
ip = request.ip || request.env['REMOTE_ADDR'] || 'unknown'
"ip:#{ip}"
end
def classify_endpoint(request)
path = request.path
method = request.request_method
return :vm_creation if method == 'POST' && path.include?('/vm')
return :vm_deletion if method == 'DELETE' && path.include?('/vm')
return :authenticated if request.env['HTTP_X_AUTH_TOKEN']
:global_per_ip
end
def rate_limit_exceeded?(client_id, endpoint_type, _request)
limit_config = @config[endpoint_type] || @config[:global_per_ip]
key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}"
current_count = @redis.get(key).to_i
current_count >= limit_config[:limit]
rescue StandardError => e
# If Redis fails, allow the request through (fail open)
warn "Rate limiter Redis error: #{e.message}"
false
end
def increment_request_count(client_id, endpoint_type)
limit_config = @config[endpoint_type] || @config[:global_per_ip]
key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}"
@redis.pipelined do |pipeline|
pipeline.incr(key)
pipeline.expire(key, limit_config[:period])
end
rescue StandardError => e
# Log error but don't fail the request
warn "Rate limiter increment error: #{e.message}"
end
def rate_limit_response(client_id, endpoint_type)
limit_config = @config[endpoint_type] || @config[:global_per_ip]
key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}"
begin
ttl = @redis.ttl(key)
rescue StandardError
ttl = limit_config[:period]
end
headers = {
'Content-Type' => 'application/json',
'X-RateLimit-Limit' => limit_config[:limit].to_s,
'X-RateLimit-Remaining' => '0',
'X-RateLimit-Reset' => (Time.now.to_i + ttl).to_s,
'Retry-After' => ttl.to_s
}
body = JSON.pretty_generate({
'ok' => false,
'error' => 'Rate limit exceeded',
'limit' => limit_config[:limit],
'period' => limit_config[:period],
'retry_after' => ttl
})
[429, headers, [body]]
end
end
end
end

View file

@ -1137,9 +1137,29 @@ module Vmpooler
result = { 'ok' => false } result = { 'ok' => false }
metrics.increment('http_requests_vm_total.post.vm.checkout') metrics.increment('http_requests_vm_total.post.vm.checkout')
payload = JSON.parse(request.body.read) # Validate and sanitize JSON body
payload = sanitize_json_body(request.body.read)
if validation_error?(payload)
status 400
return JSON.pretty_generate(payload)
end
if payload # Validate each template and count
payload.each do |template, count|
validation = validate_pool_name(template)
if validation_error?(validation)
status 400
return JSON.pretty_generate(validation)
end
validated_count = validate_vm_count(count)
if validation_error?(validated_count)
status 400
return JSON.pretty_generate(validated_count)
end
end
if payload && !payload.empty?
invalid = invalid_templates(payload) invalid = invalid_templates(payload)
if invalid.empty? if invalid.empty?
result = atomically_allocate_vms(payload) result = atomically_allocate_vms(payload)
@ -1258,6 +1278,7 @@ module Vmpooler
result = { 'ok' => false } result = { 'ok' => false }
metrics.increment('http_requests_vm_total.get.vm.template') metrics.increment('http_requests_vm_total.get.vm.template')
# Template can contain multiple pools separated by +, so validate after parsing
payload = extract_templates_from_query_params(params[:template]) payload = extract_templates_from_query_params(params[:template])
if payload if payload
@ -1287,6 +1308,13 @@ module Vmpooler
status 404 status 404
result['ok'] = false result['ok'] = false
# Validate hostname
validation = validate_hostname(params[:hostname])
if validation_error?(validation)
status 400
return JSON.pretty_generate(validation)
end
params[:hostname] = hostname_shorten(params[:hostname]) params[:hostname] = hostname_shorten(params[:hostname])
rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}") rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}")
@ -1425,6 +1453,13 @@ module Vmpooler
status 404 status 404
result['ok'] = false result['ok'] = false
# Validate hostname
validation = validate_hostname(params[:hostname])
if validation_error?(validation)
status 400
return JSON.pretty_generate(validation)
end
params[:hostname] = hostname_shorten(params[:hostname]) params[:hostname] = hostname_shorten(params[:hostname])
rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}") rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}")
@ -1455,16 +1490,21 @@ module Vmpooler
failure = [] failure = []
# Validate hostname
validation = validate_hostname(params[:hostname])
if validation_error?(validation)
status 400
return JSON.pretty_generate(validation)
end
params[:hostname] = hostname_shorten(params[:hostname]) params[:hostname] = hostname_shorten(params[:hostname])
if backend.exists?("vmpooler__vm__#{params[:hostname]}") if backend.exists?("vmpooler__vm__#{params[:hostname]}")
begin # Validate and sanitize JSON body
jdata = JSON.parse(request.body.read) jdata = sanitize_json_body(request.body.read)
rescue StandardError => e if validation_error?(jdata)
span = OpenTelemetry::Trace.current_span status 400
span.record_exception(e) return JSON.pretty_generate(jdata)
span.status = OpenTelemetry::Trace::Status.error(e.to_s)
halt 400, JSON.pretty_generate(result)
end end
# Validate data payload # Validate data payload
@ -1473,6 +1513,13 @@ module Vmpooler
when 'lifetime' when 'lifetime'
need_token! if Vmpooler::API.settings.config[:auth] need_token! if Vmpooler::API.settings.config[:auth]
# Validate lifetime is a positive integer
lifetime_int = arg.to_i
if lifetime_int <= 0
failure.push("Lifetime must be a positive integer (got #{arg})")
next
end
# in hours, defaults to one week # in hours, defaults to one week
max_lifetime_upper_limit = config['max_lifetime_upper_limit'] max_lifetime_upper_limit = config['max_lifetime_upper_limit']
if max_lifetime_upper_limit if max_lifetime_upper_limit
@ -1482,13 +1529,17 @@ module Vmpooler
end end
end end
# validate lifetime is within boundaries
unless arg.to_i > 0
failure.push("You provided a lifetime (#{arg}) but you must provide a positive number.")
end
when 'tags' when 'tags'
failure.push("You provided tags (#{arg}) as something other than a hash.") unless arg.is_a?(Hash) failure.push("You provided tags (#{arg}) as something other than a hash.") unless arg.is_a?(Hash)
# Validate each tag key and value
arg.each do |key, value|
tag_validation = validate_tag(key, value)
if validation_error?(tag_validation)
failure.push(tag_validation['error'])
end
end
failure.push("You provided unsuppored tags (#{arg}).") if config['allowed_tags'] && !(arg.keys - config['allowed_tags']).empty? failure.push("You provided unsuppored tags (#{arg}).") if config['allowed_tags'] && !(arg.keys - config['allowed_tags']).empty?
else else
failure.push("Unknown argument #{arg}.") failure.push("Unknown argument #{arg}.")
@ -1530,9 +1581,23 @@ module Vmpooler
status 404 status 404
result = { 'ok' => false } result = { 'ok' => false }
# Validate hostname
validation = validate_hostname(params[:hostname])
if validation_error?(validation)
status 400
return JSON.pretty_generate(validation)
end
# Validate disk size
validated_size = validate_disk_size(params[:size])
if validation_error?(validated_size)
status 400
return JSON.pretty_generate(validated_size)
end
params[:hostname] = hostname_shorten(params[:hostname]) params[:hostname] = hostname_shorten(params[:hostname])
if ((params[:size].to_i > 0 )and (backend.exists?("vmpooler__vm__#{params[:hostname]}"))) if backend.exists?("vmpooler__vm__#{params[:hostname]}")
result[params[:hostname]] = {} result[params[:hostname]] = {}
result[params[:hostname]]['disk'] = "+#{params[:size]}gb" result[params[:hostname]]['disk'] = "+#{params[:size]}gb"

View file

@ -0,0 +1,184 @@
# frozen_string_literal: true
require 'spec_helper'
require 'rack/test'
require 'vmpooler/api/input_validator'
describe Vmpooler::API::InputValidator do
let(:test_class) do
Class.new do
include Vmpooler::API::InputValidator
end
end
let(:validator) { test_class.new }
describe '#validate_hostname' do
it 'accepts valid hostnames' do
expect(validator.validate_hostname('test-host.example.com')).to be true
expect(validator.validate_hostname('host123')).to be true
end
it 'rejects invalid hostnames' do
result = validator.validate_hostname('invalid_host!')
expect(result['ok']).to be false
expect(result['error']).to include('Invalid hostname format')
end
it 'rejects hostnames that are too long' do
long_hostname = 'a' * 300
result = validator.validate_hostname(long_hostname)
expect(result['ok']).to be false
expect(result['error']).to include('too long')
end
it 'rejects empty hostnames' do
result = validator.validate_hostname('')
expect(result['ok']).to be false
expect(result['error']).to include('required')
end
end
describe '#validate_pool_name' do
it 'accepts valid pool names' do
expect(validator.validate_pool_name('centos-7-x86_64')).to be true
expect(validator.validate_pool_name('ubuntu-2204')).to be true
end
it 'rejects invalid pool names' do
result = validator.validate_pool_name('invalid pool!')
expect(result['ok']).to be false
expect(result['error']).to include('Invalid pool name format')
end
it 'rejects pool names that are too long' do
result = validator.validate_pool_name('a' * 150)
expect(result['ok']).to be false
expect(result['error']).to include('too long')
end
end
describe '#validate_tag' do
it 'accepts valid tags' do
expect(validator.validate_tag('project', 'test-123')).to be true
expect(validator.validate_tag('owner', 'user@example.com')).to be true
end
it 'rejects tags with invalid keys' do
result = validator.validate_tag('invalid key!', 'value')
expect(result['ok']).to be false
expect(result['error']).to include('Invalid tag key format')
end
it 'rejects tags with invalid characters in value' do
result = validator.validate_tag('key', 'value<script>')
expect(result['ok']).to be false
expect(result['error']).to include('invalid characters')
end
it 'rejects tags that are too long' do
result = validator.validate_tag('key', 'a' * 300)
expect(result['ok']).to be false
expect(result['error']).to include('too long')
end
end
describe '#validate_vm_count' do
it 'accepts valid VM counts' do
expect(validator.validate_vm_count(5)).to eq(5)
expect(validator.validate_vm_count('10')).to eq(10)
end
it 'rejects counts less than 1' do
result = validator.validate_vm_count(0)
expect(result['ok']).to be false
expect(result['error']).to include('at least 1')
end
it 'rejects counts greater than 100' do
result = validator.validate_vm_count(150)
expect(result['ok']).to be false
expect(result['error']).to include('at most 100')
end
it 'rejects non-integer values' do
result = validator.validate_vm_count('abc')
expect(result['ok']).to be false
expect(result['error']).to include('valid integer')
end
end
describe '#validate_disk_size' do
it 'accepts valid disk sizes' do
expect(validator.validate_disk_size(50)).to eq(50)
expect(validator.validate_disk_size('100')).to eq(100)
end
it 'rejects sizes less than 1' do
result = validator.validate_disk_size(0)
expect(result['ok']).to be false
end
it 'rejects sizes greater than 2048' do
result = validator.validate_disk_size(3000)
expect(result['ok']).to be false
end
end
describe '#validate_lifetime' do
it 'accepts valid lifetimes' do
expect(validator.validate_lifetime(24)).to eq(24)
expect(validator.validate_lifetime('48')).to eq(48)
end
it 'rejects lifetimes greater than 168 hours (1 week)' do
result = validator.validate_lifetime(200)
expect(result['ok']).to be false
expect(result['error']).to include('at most 168')
end
end
describe '#sanitize_json_body' do
it 'parses valid JSON' do
result = validator.sanitize_json_body('{"key": "value"}')
expect(result).to eq('key' => 'value')
end
it 'rejects invalid JSON' do
result = validator.sanitize_json_body('{invalid}')
expect(result['ok']).to be false
expect(result['error']).to include('Invalid JSON')
end
it 'rejects non-object JSON' do
result = validator.sanitize_json_body('["array"]')
expect(result['ok']).to be false
expect(result['error']).to include('must be a JSON object')
end
it 'rejects deeply nested JSON' do
deep_json = '{"a":{"b":{"c":{"d":{"e":{"f":"too deep"}}}}}}'
result = validator.sanitize_json_body(deep_json)
expect(result['ok']).to be false
expect(result['error']).to include('too complex')
end
it 'rejects bodies that are too large' do
large_json = '{"data":"' + ('a' * 20000) + '"}'
result = validator.sanitize_json_body(large_json)
expect(result['ok']).to be false
expect(result['error']).to include('too large')
end
end
describe '#validation_error?' do
it 'returns true for error responses' do
error = { 'ok' => false, 'error' => 'test error' }
expect(validator.validation_error?(error)).to be true
end
it 'returns false for successful responses' do
expect(validator.validation_error?(true)).to be false
expect(validator.validation_error?(5)).to be false
end
end
end