Skip to content

Commit 91b2a15

Browse files
Refactor MessageEncryptor and MessageVerifier
This factors common methods into a `ActiveSupport::Messages::Codec` base class. This also disentangles serialization (and deserialization) from encryption (and decryption) in `MessageEncryptor`.
1 parent 81ded4d commit 91b2a15

File tree

5 files changed

+116
-118
lines changed

5 files changed

+116
-118
lines changed

activesupport/lib/active_support/message_encryptor.rb

Lines changed: 27 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
require "openssl"
44
require "base64"
55
require "active_support/core_ext/module/attribute_accessors"
6+
require "active_support/messages/codec"
7+
require "active_support/messages/rotator"
68
require "active_support/message_verifier"
7-
require "active_support/messages/metadata"
89

910
module ActiveSupport
1011
# MessageEncryptor is a simple way to encrypt values which get stored
@@ -84,8 +85,7 @@ module ActiveSupport
8485
# the above should be combined into:
8586
#
8687
# crypt.rotate old_secret, cipher: "aes-256-cbc"
87-
class MessageEncryptor
88-
include Messages::Metadata
88+
class MessageEncryptor < Messages::Codec
8989
prepend Messages::Rotator::Encryptor
9090

9191
cattr_accessor :use_authenticated_message_encryption, instance_accessor: false, default: false
@@ -111,16 +111,6 @@ def self.dump(value)
111111
end
112112
end
113113

114-
module NullVerifier # :nodoc:
115-
def self.verify(value)
116-
value
117-
end
118-
119-
def self.generate(value)
120-
value
121-
end
122-
end
123-
124114
class InvalidMessage < StandardError; end
125115
OpenSSLCipherError = OpenSSL::Cipher::CipherError
126116

@@ -147,21 +137,13 @@ class InvalidMessage < StandardError; end
147137
# * <tt>:url_safe</tt> - Whether to encode messages using a URL-safe
148138
# encoding. Default is +false+ for backward compatibility.
149139
def initialize(secret, sign_secret = nil, cipher: nil, digest: nil, serializer: nil, url_safe: false)
140+
super(serializer: serializer || @@default_message_encryptor_serializer, url_safe: url_safe)
150141
@secret = secret
151-
@sign_secret = sign_secret
152142
@cipher = cipher || self.class.default_cipher
153143
@aead_mode = new_cipher.authenticated?
154-
@digest = digest || "SHA1" unless aead_mode?
155-
@serializer = serializer ||
156-
if @@default_message_encryptor_serializer.equal?(:marshal)
157-
Marshal
158-
elsif @@default_message_encryptor_serializer.equal?(:hybrid)
159-
JsonWithMarshalFallback
160-
elsif @@default_message_encryptor_serializer.equal?(:json)
161-
JSON
162-
end
163-
@url_safe = url_safe
164-
@verifier = resolve_verifier
144+
@verifier = if !@aead_mode
145+
MessageVerifier.new(sign_secret || secret, digest: digest || "SHA1", serializer: NullSerializer, url_safe: url_safe)
146+
end
165147
end
166148

167149
# Encrypt and sign a message. We need to sign the message in order to avoid
@@ -191,8 +173,8 @@ def initialize(secret, sign_secret = nil, cipher: nil, digest: nil, serializer:
191173
# The purpose of the message. If specified, the same purpose must be
192174
# specified when verifying the message; otherwise, verification will fail.
193175
# (See #decrypt_and_verify.)
194-
def encrypt_and_sign(value, expires_at: nil, expires_in: nil, purpose: nil)
195-
verifier.generate(_encrypt(value, expires_at: expires_at, expires_in: expires_in, purpose: purpose))
176+
def encrypt_and_sign(value, **options)
177+
sign(encrypt(serialize_with_metadata(value, **options)))
196178
end
197179

198180
# Decrypt and verify a message. We need to verify the message in order to
@@ -212,8 +194,10 @@ def encrypt_and_sign(value, expires_at: nil, expires_in: nil, purpose: nil)
212194
# encryptor.decrypt_and_verify(message) # => "bye"
213195
# encryptor.decrypt_and_verify(message, purpose: "greeting") # => nil
214196
#
215-
def decrypt_and_verify(data, purpose: nil, **)
216-
_decrypt(verifier.verify(data), purpose)
197+
def decrypt_and_verify(message, **options)
198+
deserialize_with_metadata(decrypt(verify(message)), **options)
199+
rescue TypeError, ArgumentError, ::JSON::ParserError
200+
raise InvalidMessage
217201
end
218202

219203
# Given a cipher, returns the key length of the cipher to help generate the key of desired size
@@ -222,17 +206,15 @@ def self.key_len(cipher = default_cipher)
222206
end
223207

224208
private
225-
attr_reader :serializer
226-
227-
def encode(data)
228-
@url_safe ? ::Base64.urlsafe_encode64(data, padding: false) : ::Base64.strict_encode64(data)
209+
def sign(data)
210+
@verifier ? @verifier.generate(data) : data
229211
end
230212

231-
def decode(data)
232-
@url_safe ? ::Base64.urlsafe_decode64(data) : ::Base64.strict_decode64(data)
213+
def verify(data)
214+
@verifier ? @verifier.verify(data) : data
233215
end
234216

235-
def _encrypt(value, **metadata_options)
217+
def encrypt(data)
236218
cipher = new_cipher
237219
cipher.encrypt
238220
cipher.key = @secret
@@ -241,16 +223,16 @@ def _encrypt(value, **metadata_options)
241223
iv = cipher.random_iv
242224
cipher.auth_data = "" if aead_mode?
243225

244-
encrypted_data = cipher.update(serialize_with_metadata(value, **metadata_options))
226+
encrypted_data = cipher.update(data)
245227
encrypted_data << cipher.final
246228

247229
parts = [encrypted_data, iv]
248230
parts << cipher.auth_tag(AUTH_TAG_LENGTH) if aead_mode?
249231

250-
parts.map! { |part| encode(part) }.join(SEPARATOR)
232+
join_parts(parts)
251233
end
252234

253-
def _decrypt(encrypted_message, purpose)
235+
def decrypt(encrypted_message)
254236
cipher = new_cipher
255237
encrypted_data, iv, auth_tag = extract_parts(encrypted_message)
256238

@@ -269,9 +251,7 @@ def _decrypt(encrypted_message, purpose)
269251

270252
decrypted_data = cipher.update(encrypted_data)
271253
decrypted_data << cipher.final
272-
273-
deserialize_with_metadata(decrypted_data, purpose: purpose)
274-
rescue OpenSSLCipherError, TypeError, ArgumentError, ::JSON::ParserError
254+
rescue OpenSSLCipherError
275255
raise InvalidMessage
276256
end
277257

@@ -291,6 +271,10 @@ def length_of_encoded_auth_tag
291271
@length_of_encoded_auth_tag ||= length_after_encode(AUTH_TAG_LENGTH)
292272
end
293273

274+
def join_parts(parts)
275+
parts.map! { |part| encode(part) }.join(SEPARATOR)
276+
end
277+
294278
def extract_part(encrypted_message, rindex, length)
295279
index = rindex - length
296280

@@ -322,15 +306,7 @@ def new_cipher
322306
OpenSSL::Cipher.new(@cipher)
323307
end
324308

325-
attr_reader :verifier, :aead_mode
309+
attr_reader :aead_mode
326310
alias :aead_mode? :aead_mode
327-
328-
def resolve_verifier
329-
if aead_mode?
330-
NullVerifier
331-
else
332-
MessageVerifier.new(@sign_secret || @secret, digest: @digest, serializer: NullSerializer, url_safe: @url_safe)
333-
end
334-
end
335311
end
336312
end

activesupport/lib/active_support/message_verifier.rb

Lines changed: 27 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
require "base64"
55
require "active_support/core_ext/object/blank"
66
require "active_support/security_utils"
7-
require "active_support/messages/metadata"
7+
require "active_support/messages/codec"
88
require "active_support/messages/rotator"
99

1010
module ActiveSupport
@@ -118,8 +118,7 @@ module ActiveSupport
118118
#
119119
# @verifier = ActiveSupport::MessageVerifier.new("secret", url_safe: true)
120120
# @verifier.generate("signed message") #=> URL-safe string
121-
class MessageVerifier
122-
include Messages::Metadata
121+
class MessageVerifier < Messages::Codec
123122
prepend Messages::Rotator::Verifier
124123

125124
class InvalidSignature < StandardError; end
@@ -131,17 +130,9 @@ class InvalidSignature < StandardError; end
131130

132131
def initialize(secret, digest: nil, serializer: nil, url_safe: false)
133132
raise ArgumentError, "Secret should not be nil." unless secret
133+
super(serializer: serializer || @@default_message_verifier_serializer, url_safe: url_safe)
134134
@secret = secret
135135
@digest = digest&.to_s || "SHA1"
136-
@serializer = serializer ||
137-
if @@default_message_verifier_serializer.equal?(:marshal)
138-
Marshal
139-
elsif @@default_message_verifier_serializer.equal?(:hybrid)
140-
JsonWithMarshalFallback
141-
elsif @@default_message_verifier_serializer.equal?(:json)
142-
JSON
143-
end
144-
@url_safe = url_safe
145136
end
146137

147138
# Checks if a signed message could have been generated by signing an object
@@ -153,9 +144,8 @@ def initialize(secret, digest: nil, serializer: nil, url_safe: false)
153144
#
154145
# tampered_message = signed_message.chop # editing the message invalidates the signature
155146
# verifier.valid_message?(tampered_message) # => false
156-
def valid_message?(signed_message)
157-
data, digest = get_data_and_digest_from(signed_message)
158-
digest_matches_data?(digest, data)
147+
def valid_message?(message)
148+
!!extract_encoded(message)
159149
end
160150

161151
# Decodes the signed message using the +MessageVerifier+'s secret.
@@ -195,16 +185,11 @@ def valid_message?(signed_message)
195185
# verifier.verified(message) # => "bye"
196186
# verifier.verified(message, purpose: "greeting") # => nil
197187
#
198-
def verified(signed_message, purpose: nil, **)
199-
data, digest = get_data_and_digest_from(signed_message)
200-
if digest_matches_data?(digest, data)
201-
begin
202-
deserialize_with_metadata(decode(data), purpose: purpose)
203-
rescue ArgumentError => argument_error
204-
return if argument_error.message.include?("invalid base64")
205-
raise
206-
end
207-
end
188+
def verified(message, **options)
189+
encoded = extract_encoded(message)
190+
deserialize_with_metadata(decode(encoded), **options) if encoded
191+
rescue ArgumentError => error
192+
raise unless error.message.include?("invalid base64")
208193
end
209194

210195
# Decodes the signed message using the +MessageVerifier+'s secret.
@@ -273,21 +258,27 @@ def verify(*args, **options)
273258
# The purpose of the message. If specified, the same purpose must be
274259
# specified when verifying the message; otherwise, verification will fail.
275260
# (See #verified and #verify.)
276-
def generate(value, expires_at: nil, expires_in: nil, purpose: nil)
277-
data = encode(serialize_with_metadata(value, expires_at: expires_at, expires_in: expires_in, purpose: purpose))
278-
digest = generate_digest(data)
279-
data << SEPARATOR << digest
261+
def generate(value, **options)
262+
sign_encoded(encode(serialize_with_metadata(value, **options)))
280263
end
281264

282265
private
283-
attr_reader :serializer
284-
285-
def encode(data)
286-
@url_safe ? Base64.urlsafe_encode64(data, padding: false) : Base64.strict_encode64(data)
266+
def sign_encoded(encoded)
267+
digest = generate_digest(encoded)
268+
encoded << SEPARATOR << digest
287269
end
288270

289-
def decode(data)
290-
@url_safe ? Base64.urlsafe_decode64(data) : Base64.strict_decode64(data)
271+
def extract_encoded(signed)
272+
return if signed.nil? || !signed.valid_encoding?
273+
274+
if separator_index = separator_index_for(signed)
275+
encoded = signed[0, separator_index]
276+
digest = signed[separator_index + SEPARATOR_LENGTH, digest_length_in_hex]
277+
end
278+
279+
return unless digest_matches_data?(digest, encoded)
280+
281+
encoded
291282
end
292283

293284
def generate_digest(data)
@@ -308,21 +299,7 @@ def separator_at?(signed_message, index)
308299

309300
def separator_index_for(signed_message)
310301
index = signed_message.length - digest_length_in_hex - SEPARATOR_LENGTH
311-
return if index.negative? || !separator_at?(signed_message, index)
312-
313-
index
314-
end
315-
316-
def get_data_and_digest_from(signed_message)
317-
return if signed_message.nil? || !signed_message.valid_encoding? || signed_message.empty?
318-
319-
separator_index = separator_index_for(signed_message)
320-
return if separator_index.nil?
321-
322-
data = signed_message[0, separator_index]
323-
digest = signed_message[separator_index + SEPARATOR_LENGTH, digest_length_in_hex]
324-
325-
[data, digest]
302+
index unless index.negative? || !separator_at?(signed_message, index)
326303
end
327304

328305
def digest_matches_data?(digest, data)
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# frozen_string_literal: true
2+
3+
require "active_support/messages/metadata"
4+
5+
module ActiveSupport
6+
module Messages # :nodoc:
7+
class Codec # :nodoc:
8+
include Metadata
9+
10+
def initialize(serializer:, url_safe:)
11+
@serializer =
12+
case serializer
13+
when :marshal
14+
Marshal
15+
when :hybrid
16+
JsonWithMarshalFallback
17+
when :json
18+
JSON
19+
else
20+
serializer
21+
end
22+
23+
@url_safe = url_safe
24+
end
25+
26+
private
27+
attr_reader :serializer
28+
29+
def encode(data, url_safe: @url_safe)
30+
url_safe ? ::Base64.urlsafe_encode64(data, padding: false) : ::Base64.strict_encode64(data)
31+
end
32+
33+
def decode(encoded, url_safe: @url_safe)
34+
url_safe ? ::Base64.urlsafe_decode64(encoded) : ::Base64.strict_decode64(encoded)
35+
end
36+
37+
def serialize(data)
38+
serializer.dump(data)
39+
end
40+
41+
def deserialize(serialized)
42+
serializer.load(serialized)
43+
end
44+
end
45+
end
46+
end

activesupport/lib/active_support/messages/metadata.rb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def serialize_with_metadata(data, **metadata)
2525
ActiveSupport::JSON.encode(envelope)
2626
else
2727
data = wrap_in_metadata_envelope({ "data" => data }, **metadata) if has_metadata
28-
serializer.dump(data)
28+
serialize(data)
2929
end
3030
end
3131

@@ -35,7 +35,7 @@ def deserialize_with_metadata(message, **expected_metadata)
3535
extracted = extract_from_metadata_envelope(envelope, **expected_metadata)
3636
deserialize_from_json_safe_string(extracted["message"]) if extracted
3737
else
38-
deserialized = serializer.load(message)
38+
deserialized = deserialize(message)
3939
if metadata_envelope?(deserialized)
4040
extracted = extract_from_metadata_envelope(deserialized, **expected_metadata)
4141
extracted["data"] if extracted
@@ -90,11 +90,11 @@ def parse_expiry(expires_at)
9090
end
9191

9292
def serialize_to_json_safe_string(data)
93-
::Base64.strict_encode64(serializer.dump(data))
93+
encode(serialize(data), url_safe: false)
9494
end
9595

9696
def deserialize_from_json_safe_string(string)
97-
serializer.load(::Base64.strict_decode64(string))
97+
deserialize(decode(string, url_safe: false))
9898
end
9999
end
100100
end

0 commit comments

Comments
 (0)