diff --git a/src/cbor/decoder.cr b/src/cbor/decoder.cr index 510fb45..117759f 100644 --- a/src/cbor/decoder.cr +++ b/src/cbor/decoder.cr @@ -2,6 +2,19 @@ class CBOR::Decoder @lexer : Lexer getter current_token : Token::T? + # Decode until a certain point in the history (use_cbor_discriminator helper). + def reset(value : Int32 | Int64 = 0) + @lexer.reset + while pos < value + @current_token = @lexer.next_token + end + end + + # Give the current position in the decoder (use_cbor_discriminator helper). + def pos + @lexer.io.pos + end + def initialize(input) @lexer = Lexer.new(input) @current_token = @lexer.next_token diff --git a/src/cbor/from_cbor.cr b/src/cbor/from_cbor.cr index e2610e7..0374c1f 100644 --- a/src/cbor/from_cbor.cr +++ b/src/cbor/from_cbor.cr @@ -3,6 +3,10 @@ def Object.from_cbor(string_or_io) new(parser) end +def Object.from_cbor(parser : CBOR::Decoder) + new(parser) +end + def String.new(decoder : CBOR::Decoder) decoder.read_string end diff --git a/src/cbor/lexer.cr b/src/cbor/lexer.cr index 8ba2a11..883263a 100644 --- a/src/cbor/lexer.cr +++ b/src/cbor/lexer.cr @@ -1,4 +1,5 @@ class CBOR::Lexer + property io : IO def self.new(slice : Bytes) new IO::Memory.new(slice) end @@ -8,6 +9,11 @@ class CBOR::Lexer def initialize(@io : IO) end + def reset(value : Int32 | Int64 = 0) + @io.seek value + @eof = false + end + def next_token : Token::T? return nil if @eof diff --git a/src/cbor/serializable.cr b/src/cbor/serializable.cr index 4925470..1bd6d67 100644 --- a/src/cbor/serializable.cr +++ b/src/cbor/serializable.cr @@ -345,6 +345,36 @@ module CBOR end end + macro use_cbor_discriminator(field, mapping) + {% unless mapping.is_a?(HashLiteral) || mapping.is_a?(NamedTupleLiteral) %} + {% mapping.raise "mapping argument must be a HashLiteral or a NamedTupleLiteral, not #{mapping.class_name.id}" %} + {% end %} + + # SLOW. Read everything, get the type, read everything again. + def self.new(decoder : ::CBOR::Decoder) + current_offset = decoder.pos + if v = decoder.read_value + decoder.reset current_offset + case v + when Hash(CBOR::Type, CBOR::Type) + discriminator_value = v[{{field.id.stringify}}]? + case discriminator_value + {% for key, value in mapping %} + when {{key.id.stringify}} + return {{value.id}}.from_cbor(decoder) + {% end %} + else + raise "Unknown '{{field.id}}' discriminator value: #{discriminator_value.inspect}" + end + else + raise "cannot get cbor discriminator #{ {{ field.id.stringify }} }" + end + else + raise "cannot decode cbor value" + end + end + end + # Tells this class to decode CBOR by using a field as a discriminator. # # - *field* must be the field name to use as a discriminator