From d0020becb3a08417d7e58f118987e22eea9bb42c Mon Sep 17 00:00:00 2001 From: Mahima Singh <105724608+smahima27@users.noreply.github.com> Date: Wed, 24 Dec 2025 16:53:28 +0530 Subject: [PATCH] Add rate limiting and input validation security enhancements --- Gemfile | 4 +- Gemfile.lock | 1 + lib/vmpooler/api/helpers.rb | 3 + lib/vmpooler/api/input_validator.rb | 159 ++++++++++++++++++++++ lib/vmpooler/api/rate_limiter.rb | 116 ++++++++++++++++ lib/vmpooler/api/v3.rb | 95 ++++++++++--- spec/unit/api/input_validator_spec.rb | 184 ++++++++++++++++++++++++++ 7 files changed, 545 insertions(+), 17 deletions(-) create mode 100644 lib/vmpooler/api/input_validator.rb create mode 100644 lib/vmpooler/api/rate_limiter.rb create mode 100644 spec/unit/api/input_validator_spec.rb diff --git a/Gemfile b/Gemfile index 122d6b5..0313b80 100644 --- a/Gemfile +++ b/Gemfile @@ -3,11 +3,11 @@ source ENV['GEM_SOURCE'] || 'https://rubygems.org' gemspec # Evaluate Gemfile.local if it exists -if File.exists? "#{__FILE__}.local" +if File.exist? "#{__FILE__}.local" instance_eval(File.read("#{__FILE__}.local")) end # Evaluate ~/.gemfile if it exists -if File.exists?(File.join(Dir.home, '.gemfile')) +if File.exist?(File.join(Dir.home, '.gemfile')) instance_eval(File.read(File.join(Dir.home, '.gemfile'))) end diff --git a/Gemfile.lock b/Gemfile.lock index 418f24d..2099da1 100644 --- a/Gemfile.lock +++ b/Gemfile.lock @@ -197,6 +197,7 @@ GEM PLATFORMS arm64-darwin-22 arm64-darwin-23 + arm64-darwin-25 universal-java-11 universal-java-17 x86_64-darwin-22 diff --git a/lib/vmpooler/api/helpers.rb b/lib/vmpooler/api/helpers.rb index 025e0b7..75002d4 100644 --- a/lib/vmpooler/api/helpers.rb +++ b/lib/vmpooler/api/helpers.rb @@ -1,10 +1,13 @@ # frozen_string_literal: true +require 'vmpooler/api/input_validator' + module Vmpooler class API module Helpers + include InputValidator def tracer @tracer ||= OpenTelemetry.tracer_provider.tracer('api', Vmpooler::VERSION) diff --git a/lib/vmpooler/api/input_validator.rb b/lib/vmpooler/api/input_validator.rb new file mode 100644 index 0000000..add4d6a --- /dev/null +++ b/lib/vmpooler/api/input_validator.rb @@ -0,0 +1,159 @@ +# frozen_string_literal: true + +module Vmpooler + class API + # Input validation helpers to enhance security + module InputValidator + # Maximum lengths to prevent abuse + MAX_HOSTNAME_LENGTH = 253 + MAX_TAG_KEY_LENGTH = 50 + MAX_TAG_VALUE_LENGTH = 255 + MAX_REASON_LENGTH = 500 + MAX_POOL_NAME_LENGTH = 100 + MAX_TOKEN_LENGTH = 64 + + # Valid patterns + HOSTNAME_PATTERN = /\A[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?(\.[a-z0-9]([a-z0-9-]{0,61}[a-z0-9])?)* \z/ix.freeze + POOL_NAME_PATTERN = /\A[a-zA-Z0-9_-]+\z/.freeze + TAG_KEY_PATTERN = /\A[a-zA-Z0-9_\-.]+\z/.freeze + TOKEN_PATTERN = /\A[a-zA-Z0-9\-_]+\z/.freeze + INTEGER_PATTERN = /\A\d+\z/.freeze + + class ValidationError < StandardError; end + + # Validate hostname format and length + def validate_hostname(hostname) + return error_response('Hostname is required') if hostname.nil? || hostname.empty? + return error_response('Hostname too long') if hostname.length > MAX_HOSTNAME_LENGTH + return error_response('Invalid hostname format') unless hostname.match?(HOSTNAME_PATTERN) + + true + end + + # Validate pool/template name + def validate_pool_name(pool_name) + return error_response('Pool name is required') if pool_name.nil? || pool_name.empty? + return error_response('Pool name too long') if pool_name.length > MAX_POOL_NAME_LENGTH + return error_response('Invalid pool name format') unless pool_name.match?(POOL_NAME_PATTERN) + + true + end + + # Validate tag key and value + def validate_tag(key, value) + return error_response('Tag key is required') if key.nil? || key.empty? + return error_response('Tag key too long') if key.length > MAX_TAG_KEY_LENGTH + return error_response('Invalid tag key format') unless key.match?(TAG_KEY_PATTERN) + + if value + return error_response('Tag value too long') if value.length > MAX_TAG_VALUE_LENGTH + + # Sanitize value to prevent injection attacks + sanitized_value = value.gsub(/[^\w\s\-.@:\/]/, '') + return error_response('Tag value contains invalid characters') if sanitized_value != value + end + + true + end + + # Validate token format + def validate_token_format(token) + return error_response('Token is required') if token.nil? || token.empty? + return error_response('Token too long') if token.length > MAX_TOKEN_LENGTH + return error_response('Invalid token format') unless token.match?(TOKEN_PATTERN) + + true + end + + # Validate integer parameter + def validate_integer(value, name = 'value', min: nil, max: nil) + return error_response("#{name} is required") if value.nil? + + value_str = value.to_s + return error_response("#{name} must be a valid integer") unless value_str.match?(INTEGER_PATTERN) + + int_value = value.to_i + return error_response("#{name} must be at least #{min}") if min && int_value < min + return error_response("#{name} must be at most #{max}") if max && int_value > max + + int_value + end + + # Validate VM request count + def validate_vm_count(count) + validated = validate_integer(count, 'VM count', min: 1, max: 100) + return validated if validated.is_a?(Hash) # error response + + validated + end + + # Validate disk size + def validate_disk_size(size) + validated = validate_integer(size, 'Disk size', min: 1, max: 2048) + return validated if validated.is_a?(Hash) # error response + + validated + end + + # Validate lifetime (TTL) in hours + def validate_lifetime(lifetime) + validated = validate_integer(lifetime, 'Lifetime', min: 1, max: 168) # max 1 week + return validated if validated.is_a?(Hash) # error response + + validated + end + + # Validate reason text + def validate_reason(reason) + return true if reason.nil? || reason.empty? + return error_response('Reason too long') if reason.length > MAX_REASON_LENGTH + + # Sanitize to prevent XSS/injection + sanitized = reason.gsub(/[<>"']/, '') + return error_response('Reason contains invalid characters') if sanitized != reason + + true + end + + # Sanitize JSON body to prevent injection + def sanitize_json_body(body) + return {} if body.nil? || body.empty? + + begin + parsed = JSON.parse(body) + return error_response('Request body must be a JSON object') unless parsed.is_a?(Hash) + + # Limit depth and size to prevent DoS + return error_response('Request body too complex') if json_depth(parsed) > 5 + return error_response('Request body too large') if body.length > 10_240 # 10KB max + + parsed + rescue JSON::ParserError => e + error_response("Invalid JSON: #{e.message}") + end + end + + # Check if validation result is an error + def validation_error?(result) + result.is_a?(Hash) && result['ok'] == false + end + + private + + def error_response(message) + { 'ok' => false, 'error' => message } + end + + def json_depth(obj, depth = 0) + return depth unless obj.is_a?(Hash) || obj.is_a?(Array) + return depth + 1 if obj.empty? + + if obj.is_a?(Hash) + depth + 1 + obj.values.map { |v| json_depth(v, 0) }.max + else + depth + 1 + obj.map { |v| json_depth(v, 0) }.max + end + end + end + end +end diff --git a/lib/vmpooler/api/rate_limiter.rb b/lib/vmpooler/api/rate_limiter.rb new file mode 100644 index 0000000..8ecfb62 --- /dev/null +++ b/lib/vmpooler/api/rate_limiter.rb @@ -0,0 +1,116 @@ +# frozen_string_literal: true + +module Vmpooler + class API + # Rate limiter middleware to protect against abuse + # Uses Redis to track request counts per IP and token + class RateLimiter + DEFAULT_LIMITS = { + global_per_ip: { limit: 100, period: 60 }, # 100 requests per minute per IP + authenticated: { limit: 500, period: 60 }, # 500 requests per minute with token + vm_creation: { limit: 20, period: 60 }, # 20 VM creations per minute + vm_deletion: { limit: 50, period: 60 } # 50 VM deletions per minute + }.freeze + + def initialize(app, redis, config = {}) + @app = app + @redis = redis + @config = DEFAULT_LIMITS.merge(config[:rate_limits] || {}) + @enabled = config.fetch(:rate_limiting_enabled, true) + end + + def call(env) + return @app.call(env) unless @enabled + + request = Rack::Request.new(env) + client_id = identify_client(request) + endpoint_type = classify_endpoint(request) + + # Check rate limits + return rate_limit_response(client_id, endpoint_type) if rate_limit_exceeded?(client_id, endpoint_type, request) + + # Track the request + increment_request_count(client_id, endpoint_type) + + @app.call(env) + end + + private + + def identify_client(request) + # Prioritize token-based identification for authenticated requests + token = request.env['HTTP_X_AUTH_TOKEN'] + return "token:#{token}" if token && !token.empty? + + # Fall back to IP address + ip = request.ip || request.env['REMOTE_ADDR'] || 'unknown' + "ip:#{ip}" + end + + def classify_endpoint(request) + path = request.path + method = request.request_method + + return :vm_creation if method == 'POST' && path.include?('/vm') + return :vm_deletion if method == 'DELETE' && path.include?('/vm') + return :authenticated if request.env['HTTP_X_AUTH_TOKEN'] + + :global_per_ip + end + + def rate_limit_exceeded?(client_id, endpoint_type, _request) + limit_config = @config[endpoint_type] || @config[:global_per_ip] + key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}" + + current_count = @redis.get(key).to_i + current_count >= limit_config[:limit] + rescue StandardError => e + # If Redis fails, allow the request through (fail open) + warn "Rate limiter Redis error: #{e.message}" + false + end + + def increment_request_count(client_id, endpoint_type) + limit_config = @config[endpoint_type] || @config[:global_per_ip] + key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}" + + @redis.pipelined do |pipeline| + pipeline.incr(key) + pipeline.expire(key, limit_config[:period]) + end + rescue StandardError => e + # Log error but don't fail the request + warn "Rate limiter increment error: #{e.message}" + end + + def rate_limit_response(client_id, endpoint_type) + limit_config = @config[endpoint_type] || @config[:global_per_ip] + key = "vmpooler__ratelimit__#{endpoint_type}__#{client_id}" + + begin + ttl = @redis.ttl(key) + rescue StandardError + ttl = limit_config[:period] + end + + headers = { + 'Content-Type' => 'application/json', + 'X-RateLimit-Limit' => limit_config[:limit].to_s, + 'X-RateLimit-Remaining' => '0', + 'X-RateLimit-Reset' => (Time.now.to_i + ttl).to_s, + 'Retry-After' => ttl.to_s + } + + body = JSON.pretty_generate({ + 'ok' => false, + 'error' => 'Rate limit exceeded', + 'limit' => limit_config[:limit], + 'period' => limit_config[:period], + 'retry_after' => ttl + }) + + [429, headers, [body]] + end + end + end +end diff --git a/lib/vmpooler/api/v3.rb b/lib/vmpooler/api/v3.rb index 30b5b7c..1c7b788 100644 --- a/lib/vmpooler/api/v3.rb +++ b/lib/vmpooler/api/v3.rb @@ -1085,9 +1085,29 @@ module Vmpooler result = { 'ok' => false } metrics.increment('http_requests_vm_total.post.vm.checkout') - payload = JSON.parse(request.body.read) + # Validate and sanitize JSON body + payload = sanitize_json_body(request.body.read) + if validation_error?(payload) + status 400 + return JSON.pretty_generate(payload) + end - if payload + # Validate each template and count + payload.each do |template, count| + validation = validate_pool_name(template) + if validation_error?(validation) + status 400 + return JSON.pretty_generate(validation) + end + + validated_count = validate_vm_count(count) + if validation_error?(validated_count) + status 400 + return JSON.pretty_generate(validated_count) + end + end + + if payload && !payload.empty? invalid = invalid_templates(payload) if invalid.empty? result = atomically_allocate_vms(payload) @@ -1206,6 +1226,7 @@ module Vmpooler result = { 'ok' => false } metrics.increment('http_requests_vm_total.get.vm.template') + # Template can contain multiple pools separated by +, so validate after parsing payload = extract_templates_from_query_params(params[:template]) if payload @@ -1235,6 +1256,13 @@ module Vmpooler status 404 result['ok'] = false + # Validate hostname + validation = validate_hostname(params[:hostname]) + if validation_error?(validation) + status 400 + return JSON.pretty_generate(validation) + end + params[:hostname] = hostname_shorten(params[:hostname]) rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}") @@ -1373,6 +1401,13 @@ module Vmpooler status 404 result['ok'] = false + # Validate hostname + validation = validate_hostname(params[:hostname]) + if validation_error?(validation) + status 400 + return JSON.pretty_generate(validation) + end + params[:hostname] = hostname_shorten(params[:hostname]) rdata = backend.hgetall("vmpooler__vm__#{params[:hostname]}") @@ -1403,16 +1438,21 @@ module Vmpooler failure = [] + # Validate hostname + validation = validate_hostname(params[:hostname]) + if validation_error?(validation) + status 400 + return JSON.pretty_generate(validation) + end + params[:hostname] = hostname_shorten(params[:hostname]) if backend.exists?("vmpooler__vm__#{params[:hostname]}") - begin - jdata = JSON.parse(request.body.read) - rescue StandardError => e - span = OpenTelemetry::Trace.current_span - span.record_exception(e) - span.status = OpenTelemetry::Trace::Status.error(e.to_s) - halt 400, JSON.pretty_generate(result) + # Validate and sanitize JSON body + jdata = sanitize_json_body(request.body.read) + if validation_error?(jdata) + status 400 + return JSON.pretty_generate(jdata) end # Validate data payload @@ -1421,6 +1461,13 @@ module Vmpooler when 'lifetime' need_token! if Vmpooler::API.settings.config[:auth] + # Validate lifetime is a positive integer + lifetime_int = arg.to_i + if lifetime_int <= 0 + failure.push("Lifetime must be a positive integer (got #{arg})") + next + end + # in hours, defaults to one week max_lifetime_upper_limit = config['max_lifetime_upper_limit'] if max_lifetime_upper_limit @@ -1430,13 +1477,17 @@ module Vmpooler end end - # validate lifetime is within boundaries - unless arg.to_i > 0 - failure.push("You provided a lifetime (#{arg}) but you must provide a positive number.") - end - when 'tags' failure.push("You provided tags (#{arg}) as something other than a hash.") unless arg.is_a?(Hash) + + # Validate each tag key and value + arg.each do |key, value| + tag_validation = validate_tag(key, value) + if validation_error?(tag_validation) + failure.push(tag_validation['error']) + end + end + failure.push("You provided unsuppored tags (#{arg}).") if config['allowed_tags'] && !(arg.keys - config['allowed_tags']).empty? else failure.push("Unknown argument #{arg}.") @@ -1478,9 +1529,23 @@ module Vmpooler status 404 result = { 'ok' => false } + # Validate hostname + validation = validate_hostname(params[:hostname]) + if validation_error?(validation) + status 400 + return JSON.pretty_generate(validation) + end + + # Validate disk size + validated_size = validate_disk_size(params[:size]) + if validation_error?(validated_size) + status 400 + return JSON.pretty_generate(validated_size) + end + params[:hostname] = hostname_shorten(params[:hostname]) - if ((params[:size].to_i > 0 )and (backend.exists?("vmpooler__vm__#{params[:hostname]}"))) + if backend.exists?("vmpooler__vm__#{params[:hostname]}") result[params[:hostname]] = {} result[params[:hostname]]['disk'] = "+#{params[:size]}gb" diff --git a/spec/unit/api/input_validator_spec.rb b/spec/unit/api/input_validator_spec.rb new file mode 100644 index 0000000..24982ed --- /dev/null +++ b/spec/unit/api/input_validator_spec.rb @@ -0,0 +1,184 @@ +# frozen_string_literal: true + +require 'spec_helper' +require 'rack/test' +require 'vmpooler/api/input_validator' + +describe Vmpooler::API::InputValidator do + let(:test_class) do + Class.new do + include Vmpooler::API::InputValidator + end + end + let(:validator) { test_class.new } + + describe '#validate_hostname' do + it 'accepts valid hostnames' do + expect(validator.validate_hostname('test-host.example.com')).to be true + expect(validator.validate_hostname('host123')).to be true + end + + it 'rejects invalid hostnames' do + result = validator.validate_hostname('invalid_host!') + expect(result['ok']).to be false + expect(result['error']).to include('Invalid hostname format') + end + + it 'rejects hostnames that are too long' do + long_hostname = 'a' * 300 + result = validator.validate_hostname(long_hostname) + expect(result['ok']).to be false + expect(result['error']).to include('too long') + end + + it 'rejects empty hostnames' do + result = validator.validate_hostname('') + expect(result['ok']).to be false + expect(result['error']).to include('required') + end + end + + describe '#validate_pool_name' do + it 'accepts valid pool names' do + expect(validator.validate_pool_name('centos-7-x86_64')).to be true + expect(validator.validate_pool_name('ubuntu-2204')).to be true + end + + it 'rejects invalid pool names' do + result = validator.validate_pool_name('invalid pool!') + expect(result['ok']).to be false + expect(result['error']).to include('Invalid pool name format') + end + + it 'rejects pool names that are too long' do + result = validator.validate_pool_name('a' * 150) + expect(result['ok']).to be false + expect(result['error']).to include('too long') + end + end + + describe '#validate_tag' do + it 'accepts valid tags' do + expect(validator.validate_tag('project', 'test-123')).to be true + expect(validator.validate_tag('owner', 'user@example.com')).to be true + end + + it 'rejects tags with invalid keys' do + result = validator.validate_tag('invalid key!', 'value') + expect(result['ok']).to be false + expect(result['error']).to include('Invalid tag key format') + end + + it 'rejects tags with invalid characters in value' do + result = validator.validate_tag('key', 'value