diff options
Diffstat (limited to 'app/controllers/concerns')
7 files changed, 187 insertions, 41 deletions
diff --git a/app/controllers/concerns/account_controller_concern.rb b/app/controllers/concerns/account_controller_concern.rb index 2f7d84df0..e9cff22ca 100644 --- a/app/controllers/concerns/account_controller_concern.rb +++ b/app/controllers/concerns/account_controller_concern.rb @@ -10,7 +10,8 @@ module AccountControllerConcern included do before_action :set_instance_presenter - before_action :set_link_headers, if: -> { request.format.nil? || request.format == :html } + + after_action :set_link_headers, if: -> { request.format.nil? || request.format == :html } end private diff --git a/app/controllers/concerns/admin_export_controller_concern.rb b/app/controllers/concerns/admin_export_controller_concern.rb index b40c76557..4ac48a04b 100644 --- a/app/controllers/concerns/admin_export_controller_concern.rb +++ b/app/controllers/concerns/admin_export_controller_concern.rb @@ -26,14 +26,4 @@ module AdminExportControllerConcern def import_params params.require(:admin_import).permit(:data) end - - def import_data_path - params[:admin_import][:data].path - end - - def parse_import_data!(default_headers) - data = CSV.read(import_data_path, headers: true, encoding: 'UTF-8') - data = CSV.read(import_data_path, headers: default_headers, encoding: 'UTF-8') unless data.headers&.first&.strip&.include?(default_headers[0]) - @data = data.reject(&:blank?) - end end diff --git a/app/controllers/concerns/cache_concern.rb b/app/controllers/concerns/cache_concern.rb index 05e431b19..a5a9ba3e1 100644 --- a/app/controllers/concerns/cache_concern.rb +++ b/app/controllers/concerns/cache_concern.rb @@ -3,6 +3,158 @@ module CacheConcern extend ActiveSupport::Concern + module ActiveRecordCoder + EMPTY_HASH = {}.freeze + + class << self + def dump(record) + instances = InstanceTracker.new + serialized_associations = serialize_associations(record, instances) + serialized_records = instances.map { |r| serialize_record(r) } + [serialized_associations, *serialized_records] + end + + def load(payload) + instances = InstanceTracker.new + serialized_associations, *serialized_records = payload + serialized_records.each { |attrs| instances.push(deserialize_record(*attrs)) } + deserialize_associations(serialized_associations, instances) + end + + private + + # Records without associations, or which have already been visited before, + # are serialized by their id alone. + # + # Records with associations are serialized as a two-element array including + # their id and the record's association cache. + # + def serialize_associations(record, instances) + return unless record + + if (id = instances.lookup(record)) + payload = id + else + payload = instances.push(record) + + cached_associations = record.class.reflect_on_all_associations.select do |reflection| + record.association_cached?(reflection.name) + end + + unless cached_associations.empty? + serialized_associations = cached_associations.map do |reflection| + association = record.association(reflection.name) + + serialized_target = if reflection.collection? + association.target.map { |target_record| serialize_associations(target_record, instances) } + else + serialize_associations(association.target, instances) + end + + [reflection.name, serialized_target] + end + + payload = [payload, serialized_associations] + end + end + + payload + end + + def deserialize_associations(payload, instances) + return unless payload + + id, associations = payload + record = instances.fetch(id) + + associations&.each do |name, serialized_target| + begin + association = record.association(name) + rescue ActiveRecord::AssociationNotFoundError + raise AssociationMissingError, "undefined association: #{name}" + end + + target = if association.reflection.collection? + serialized_target.map! { |serialized_record| deserialize_associations(serialized_record, instances) } + else + deserialize_associations(serialized_target, instances) + end + + association.target = target + end + + record + end + + def serialize_record(record) + arguments = [record.class.name, attributes_for_database(record)] + arguments << true if record.new_record? + arguments + end + + if Rails.gem_version >= Gem::Version.new('7.0') + def attributes_for_database(record) + attributes = record.attributes_for_database + attributes.transform_values! { |attr| attr.is_a?(::ActiveModel::Type::Binary::Data) ? attr.to_s : attr } + attributes + end + else + def attributes_for_database(record) + attributes = record.instance_variable_get(:@attributes).send(:attributes).transform_values(&:value_for_database) + attributes.transform_values! { |attr| attr.is_a?(::ActiveModel::Type::Binary::Data) ? attr.to_s : attr } + attributes + end + end + + def deserialize_record(class_name, attributes_from_database, new_record = false) # rubocop:disable Style/OptionalBooleanParameter + begin + klass = Object.const_get(class_name) + rescue NameError + raise ClassMissingError, "undefined class: #{class_name}" + end + + # Ideally we'd like to call `klass.instantiate`, however it doesn't allow to pass + # wether the record was persisted or not. + attributes = klass.attributes_builder.build_from_database(attributes_from_database, EMPTY_HASH) + klass.allocate.init_with_attributes(attributes, new_record) + end + end + + class Error < StandardError + end + + class ClassMissingError < Error + end + + class AssociationMissingError < Error + end + + class InstanceTracker + def initialize + @instances = [] + @ids = {}.compare_by_identity + end + + def map(&block) + @instances.map(&block) + end + + def fetch(...) + @instances.fetch(...) + end + + def push(instance) + id = @ids[instance] = @instances.size + @instances << instance + id + end + + def lookup(instance) + @ids[instance] + end + end + end + def render_with_cache(**options) raise ArgumentError, 'only JSON render calls are supported' unless options.key?(:json) || block_given? @@ -34,8 +186,13 @@ module CacheConcern raw = raw.cache_ids.to_a if raw.is_a?(ActiveRecord::Relation) return [] if raw.empty? - cached_keys_with_value = Rails.cache.read_multi(*raw).transform_keys(&:id) - uncached_ids = raw.map(&:id) - cached_keys_with_value.keys + cached_keys_with_value = begin + Rails.cache.read_multi(*raw).transform_keys(&:id).transform_values { |r| ActiveRecordCoder.load(r) } + rescue ActiveRecordCoder::Error + {} # The serialization format may have changed, let's pretend it's a cache miss. + end + + uncached_ids = raw.map(&:id) - cached_keys_with_value.keys klass.reload_stale_associations!(cached_keys_with_value.values) if klass.respond_to?(:reload_stale_associations!) @@ -43,7 +200,7 @@ module CacheConcern uncached = klass.where(id: uncached_ids).with_includes.index_by(&:id) uncached.each_value do |item| - Rails.cache.write(item, item) + Rails.cache.write(item, ActiveRecordCoder.dump(item)) end end diff --git a/app/controllers/concerns/rate_limit_headers.rb b/app/controllers/concerns/rate_limit_headers.rb index b8696df73..30702f00e 100644 --- a/app/controllers/concerns/rate_limit_headers.rb +++ b/app/controllers/concerns/rate_limit_headers.rb @@ -6,13 +6,11 @@ module RateLimitHeaders class_methods do def override_rate_limit_headers(method_name, options = {}) around_action(only: method_name, if: :current_account) do |_controller, block| - begin - block.call - ensure - rate_limiter = RateLimiter.new(current_account, options) - rate_limit_headers = rate_limiter.to_headers - response.headers.merge!(rate_limit_headers) unless response.headers['X-RateLimit-Remaining'].present? && rate_limit_headers['X-RateLimit-Remaining'].to_i > response.headers['X-RateLimit-Remaining'].to_i - end + block.call + ensure + rate_limiter = RateLimiter.new(current_account, options) + rate_limit_headers = rate_limiter.to_headers + response.headers.merge!(rate_limit_headers) unless response.headers['X-RateLimit-Remaining'].present? && rate_limit_headers['X-RateLimit-Remaining'].to_i > response.headers['X-RateLimit-Remaining'].to_i end end end @@ -67,6 +65,6 @@ module RateLimitHeaders end def reset_period_offset - api_throttle_data[:period] - request_time.to_i % api_throttle_data[:period] + api_throttle_data[:period] - (request_time.to_i % api_throttle_data[:period]) end end diff --git a/app/controllers/concerns/session_tracking_concern.rb b/app/controllers/concerns/session_tracking_concern.rb index eaaa4ac59..3f56c0d02 100644 --- a/app/controllers/concerns/session_tracking_concern.rb +++ b/app/controllers/concerns/session_tracking_concern.rb @@ -13,6 +13,7 @@ module SessionTrackingConcern def set_session_activity return unless session_needs_update? + current_session.touch end diff --git a/app/controllers/concerns/signature_verification.rb b/app/controllers/concerns/signature_verification.rb index 4502da698..931725943 100644 --- a/app/controllers/concerns/signature_verification.rb +++ b/app/controllers/concerns/signature_verification.rb @@ -46,11 +46,11 @@ module SignatureVerification end def require_account_signature! - render plain: signature_verification_failure_reason, status: signature_verification_failure_code unless signed_request_account + render json: signature_verification_failure_reason, status: signature_verification_failure_code unless signed_request_account end def require_actor_signature! - render plain: signature_verification_failure_reason, status: signature_verification_failure_code unless signed_request_actor + render json: signature_verification_failure_reason, status: signature_verification_failure_code unless signed_request_actor end def signed_request? @@ -97,11 +97,11 @@ module SignatureVerification actor = stoplight_wrap_request { actor_refresh_key!(actor) } - raise SignatureVerificationError, "Public key not found for key #{signature_params['keyId']}" if actor.nil? + raise SignatureVerificationError, "Could not refresh public key #{signature_params['keyId']}" if actor.nil? return actor unless verify_signature(actor, signature, compare_signed_string).nil? - fail_with! "Verification failed for #{actor.to_log_human_identifier} #{actor.uri} using rsa-sha256 (RSASSA-PKCS1-v1_5 with SHA-256)" + fail_with! "Verification failed for #{actor.to_log_human_identifier} #{actor.uri} using rsa-sha256 (RSASSA-PKCS1-v1_5 with SHA-256)", signed_string: compare_signed_string, signature: signature_params['signature'] rescue SignatureVerificationError => e fail_with! e.message rescue HTTP::Error, OpenSSL::SSL::SSLError => e @@ -118,8 +118,8 @@ module SignatureVerification private - def fail_with!(message) - @signature_verification_failure_reason = message + def fail_with!(message, **options) + @signature_verification_failure_reason = { error: message }.merge(options) @signed_request_actor = nil end @@ -138,7 +138,7 @@ module SignatureVerification end def signed_headers - signature_params.fetch('headers', signature_algorithm == 'hs2019' ? '(created)' : 'date').downcase.split(' ') + signature_params.fetch('headers', signature_algorithm == 'hs2019' ? '(created)' : 'date').downcase.split end def verify_signature_strength! @@ -165,6 +165,7 @@ module SignatureVerification end raise SignatureVerificationError, "Invalid Digest value. The provided Digest value is not a SHA-256 digest. Given digest: #{sha256[1]}" if digest_size != 32 + raise SignatureVerificationError, "Invalid Digest value. Computed SHA-256 digest: #{body_digest}; given: #{sha256[1]}" end @@ -209,8 +210,8 @@ module SignatureVerification end expires_time = Time.at(signature_params['expires'].to_i).utc if signature_params['expires'].present? - rescue ArgumentError - return false + rescue ArgumentError => e + raise SignatureVerificationError, "Invalid Date header: #{e.message}" end expires_time ||= created_time + 5.minutes unless created_time.nil? @@ -227,7 +228,7 @@ module SignatureVerification end def to_header_name(name) - name.split(/-/).map(&:capitalize).join('-') + name.split('-').map(&:capitalize).join('-') end def missing_required_signature_parameters? diff --git a/app/controllers/concerns/two_factor_authentication_concern.rb b/app/controllers/concerns/two_factor_authentication_concern.rb index c9477a1d4..b30cd354d 100644 --- a/app/controllers/concerns/two_factor_authentication_concern.rb +++ b/app/controllers/concerns/two_factor_authentication_concern.rb @@ -57,10 +57,10 @@ module TwoFactorAuthenticationConcern if valid_webauthn_credential?(user, webauthn_credential) on_authentication_success(user, :webauthn) - render json: { redirect_path: after_sign_in_path_for(user) }, status: :ok + render json: { redirect_path: after_sign_in_path_for(user) }, status: 200 else on_authentication_failure(user, :webauthn, :invalid_credential) - render json: { error: t('webauthn_credentials.invalid_credential') }, status: :unprocessable_entity + render json: { error: t('webauthn_credentials.invalid_credential') }, status: 422 end end @@ -81,13 +81,11 @@ module TwoFactorAuthenticationConcern @body_classes = 'lighter' @webauthn_enabled = user.webauthn_enabled? - @scheme_type = begin - if user.webauthn_enabled? && user_params[:otp_attempt].blank? - 'webauthn' - else - 'totp' - end - end + @scheme_type = if user.webauthn_enabled? && user_params[:otp_attempt].blank? + 'webauthn' + else + 'totp' + end set_locale { render :two_factor } end |