Skip to content

Commit 60eac4b

Browse files
committed
Improve get_names get_types
get_names, get_types changes: * added `service_fields` support * support nullable types * add `record*` support (flatten takes a single place) * fix `get_types` prints name instead of type in some cases * `get_types` exports `union` as `union_type`, `union_value` get_names, get_types are also added to a compiled schema, and can be called from it without any arguments. Closes #58, #56
1 parent 6f8d05c commit 60eac4b

File tree

3 files changed

+161
-77
lines changed

3 files changed

+161
-77
lines changed

avro_schema/frontend.lua

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1243,11 +1243,66 @@ export_helper = function(node, already_built)
12431243
end
12441244
end
12451245

1246+
local get_names_helper
1247+
get_names_helper = function(res, pos, names, rec)
1248+
local fields = rec.fields
1249+
for i = 1, #fields do
1250+
local ftype = fields[i].type
1251+
insert(names, fields[i].name)
1252+
if type(ftype) == 'string' then
1253+
res[pos] = concat(names, '.')
1254+
pos = pos + 1
1255+
elseif ftype.type == 'record' and not ftype.nullable then
1256+
pos = get_names_helper(res, pos, names, ftype)
1257+
elseif not ftype.type then -- union
1258+
local path = concat(names, '.')
1259+
res[pos] = path .. '.$type$'
1260+
res[pos + 1] = path
1261+
pos = pos + 2
1262+
else
1263+
-- record*, scalar*, fixed, array, map
1264+
res[pos] = concat(names, '.')
1265+
pos = pos + 1
1266+
end
1267+
remove(names)
1268+
end
1269+
return pos
1270+
end
1271+
1272+
local get_types_helper
1273+
get_types_helper = function(res, pos, rec)
1274+
local fields = rec.fields
1275+
for i = 1, #fields do
1276+
local ftype = fields[i].type
1277+
if type(ftype) == 'string' then
1278+
res[pos] = ftype
1279+
pos = pos + 1
1280+
elseif ftype.type == 'record' and not ftype.nullable then
1281+
pos = get_types_helper(res, pos, ftype)
1282+
elseif not ftype.type then -- union
1283+
res[pos] = "union_type"
1284+
res[pos + 1] = "union_value"
1285+
pos = pos + 2
1286+
else
1287+
-- record*, scalar*, fixed, array, map
1288+
local xtype = ftype.type
1289+
assert(type(xtype) == "string",
1290+
"Deep type declarations are not supported")
1291+
if ftype.nullable then xtype = xtype .. "*" end
1292+
res[pos] = xtype
1293+
pos = pos + 1
1294+
end
1295+
end
1296+
return pos
1297+
end
1298+
12461299
return {
12471300
create_schema = create_schema,
12481301
validate_data = validate_data,
12491302
create_ir = create_ir,
12501303
get_enum_symbol_map = get_enum_symbol_map,
12511304
get_union_tag_map = get_union_tag_map,
1252-
export_helper = export_helper
1305+
export_helper = export_helper,
1306+
get_names_helper = get_names_helper,
1307+
get_types_helper = get_types_helper
12531308
}

avro_schema/init.lua

Lines changed: 46 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,20 @@ v0 = encode_proc(r, v0)]],
381381
})
382382
end
383383

384-
-- service fields, a subset of AVRO types
385-
local valid_service_field = {
386-
boolean = 1, int = 1, long = 1, float = 1,
387-
double = 1, string = 1, bytes = 1
388-
}
384+
local function validate_service_fields(sfs)
385+
-- service fields, a subset of AVRO types
386+
local valid_service_field = {
387+
boolean = 1, int = 1, long = 1, float = 1,
388+
double = 1, string = 1, bytes = 1
389+
}
390+
for i, field in ipairs(sfs) do
391+
if not valid_service_field[field] then
392+
error(format('service_fields[%d]: Invalid type: %s', i, field), 0)
393+
end
394+
end
395+
end
389396

397+
local get_names, get_types
390398
-- compile(schema)
391399
-- compile(schema1, schema2)
392400
-- compile({schema1, schema2, downgrade = true, service_fields = { ... }})
@@ -402,18 +410,18 @@ local function compile(...)
402410
args = args[1]
403411
end
404412
local service_fields = args.service_fields or {}
413+
-- Make private copy for get_names & get_types.
414+
service_fields = table.copy(service_fields)
405415
-- would be deleted after #85
406416
local alpha_nullable_record_xflatten = args.alpha_nullable_record_xflatten
407417
if type(service_fields) ~= 'table' then
408418
error('service_fields: Expecting a table', 0)
409419
end
410-
for i, field in ipairs(service_fields) do
411-
if not valid_service_field[field] then
412-
error(format('service_fields[%d]: Invalid type: %s', i, field), 0)
413-
end
414-
end
420+
validate_service_fields(service_fields)
415421
local list = {}
422+
local handler_schema_to
416423
for i = 1, n do
424+
handler_schema_to = args[i]
417425
insert(list, get_schema(args[i]))
418426
end
419427
if #list == 0 then
@@ -460,79 +468,46 @@ local function compile(...)
460468
xflatten = process_lua.xflatten,
461469
flatten_msgpack = process_msgpack.flatten,
462470
unflatten_msgpack = process_msgpack.unflatten,
463-
xflatten_msgpack = process_msgpack.xflatten
471+
xflatten_msgpack = process_msgpack.xflatten,
472+
get_names = function ()
473+
return get_names(handler_schema_to, service_fields)
474+
end,
475+
get_types = function ()
476+
return get_types(handler_schema_to, service_fields)
477+
end
464478
}
465479
end
466480
end
467481

468482
-----------------------------------------------------------------------
469483
-- misc
470-
local get_names_helper
471-
get_names_helper = function(res, pos, names, rec)
472-
local fields = rec.fields
473-
for i = 1, #fields do
474-
local ftype = fields[i].type
475-
insert(names, fields[i].name)
476-
if type(ftype) == 'string' then
477-
res[pos] = concat(names, '.')
478-
pos = pos + 1
479-
elseif ftype.type == 'record' then
480-
pos = get_names_helper(res, pos, names, ftype)
481-
elseif not ftype.type then -- union
482-
local path = concat(names, '.')
483-
res[pos] = path .. '.$type$'
484-
res[pos + 1] = path
485-
pos = pos + 2
486-
else
487-
res[pos] = concat(names, '.')
488-
pos = pos + 1
489-
end
490-
remove(names)
491-
end
492-
return pos
493-
end
494-
495-
local function get_names(schema_h)
484+
get_names = function(schema_h, service_fields)
496485
local schema = get_schema(schema_h)
497-
if type(schema) == 'table' and schema.type == 'record' then
498-
local res = {}
499-
local names = {}
500-
get_names_helper(res, 1, names, schema)
501-
return res
502-
else
503-
return {}
504-
end
505-
end
506-
507-
local get_types_helper
508-
get_types_helper = function(res, pos, rec)
509-
local fields = rec.fields
510-
for i = 1, #fields do
511-
local ftype = fields[i].type
512-
if type(ftype) == 'string' then
513-
res[pos] = ftype
514-
pos = pos + 1
515-
elseif ftype.type == 'record' then
516-
pos = get_types_helper(res, pos, ftype)
517-
elseif not ftype.type then -- union
518-
pos = pos + 2
519-
else
520-
res[pos] = ftype.name or ftype.type
521-
pos = pos + 1
522-
end
486+
service_fields = service_fields or {}
487+
validate_service_fields(service_fields)
488+
local res = {}
489+
for i = 1, #service_fields do
490+
insert(res, "$service_field$")
523491
end
524-
return pos
492+
assert(type(schema) == 'table' and schema.type == 'record' and
493+
not schema.nullable, "expected non-nullable record at the top level")
494+
local names = {}
495+
front.get_names_helper(res, #res + 1, names, schema)
496+
return res
525497
end
526498

527-
local function get_types(schema_h)
499+
get_types = function(schema_h, service_fields)
528500
local schema = get_schema(schema_h)
529-
if type(schema) == 'table' and schema.type == 'record' then
530-
local res = {}
531-
get_types_helper(res, 1, schema)
532-
return res
533-
else
534-
return {}
501+
service_fields = service_fields or {}
502+
validate_service_fields(service_fields)
503+
local res = {}
504+
for _, sf in ipairs(service_fields) do
505+
insert(res, sf)
535506
end
507+
assert(type(schema) == 'table' and schema.type == 'record' and
508+
not schema.nullable, "expected non-nullable record at the top level")
509+
front.get_types_helper(res, #res + 1, schema)
510+
return res
536511
end
537512

538513
local function export(schema_h)

test/api_tests/var.lua

Lines changed: 59 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ local msgpack = require('msgpack')
55

66
local test = tap.test('api-tests')
77

8-
test:plan(47)
8+
test:plan(53)
99

1010
test:is_deeply({schema.create()}, {false, 'Unknown Avro type: nil'},
1111
'error unknown type')
@@ -173,16 +173,20 @@ test:test("compile / int", function(test)
173173
end)
174174

175175
-- get_names
176-
test:is_deeply(schema.get_names(int), {}, 'get_names (int)')
176+
local ok, err_msg = pcall(schema.get_names, int)
177+
test:like(tostring(err_msg), "expected non%-nullable record at the top level",
178+
'get_names (int)')
177179
test:is_deeply(schema.get_names(foobar),
178180
{'A.X','A.Y','B.X','B.Y','C.$type$','C','D'},
179181
'get_names (FooBar)')
180182

181183
-- get_types
182-
test:is_deeply(schema.get_types(int), {}, 'get_types (int)')
184+
local ok, err_msg = pcall(schema.get_types, int)
185+
test:like(tostring(err_msg), "expected non%-nullable record at the top level",
186+
'get_types (int)')
183187
test:is_deeply(schema.get_types(foobar),
184-
{'double','double','double','double',nil,nil,'string'},
185-
'get_types (FooBar)')
188+
{'double','double','double','double','union_type','union_value',
189+
'string'}, 'get_types (FooBar)')
186190

187191
-- is
188192
test:is(schema.is(int), true, 'schema handle is a schema (1)')
@@ -333,5 +337,55 @@ local ok, err_mesg = compiled.xflatten({y={f1="a"}})
333337
test:like(err_mesg, "xflatten for nullable record is on developement stage",
334338
"nullable record xflatten prohibited")
335339

340+
local get_names_types = [[
341+
{"type": "record", "name": "X", "fields":[
342+
{"name": "x1", "type":"string*"},
343+
{"name": "x2", "type": {
344+
"type":"record", "name": "Y", "fields": [
345+
{"name":"y1", "type": "string"},
346+
{"name":"y2", "type": "long"}]}},
347+
{"name": "x3", "type": {
348+
"type":"record*","name": "Z", "fields": [
349+
{"name":"z1", "type": "string*"},
350+
{"name":"z2", "type": "long*"}]}},
351+
{"name": "x4", "type": ["int", "string*" ]},
352+
{"name": "x5", "type": {"type": "array*", "items": "int*"}},
353+
{"name": "x6", "type": {"type": "map", "values": "float"}},
354+
{"name": "x7", "type": {"type": "fixed*", "name":"W", "size":5}}
355+
]}
356+
]]
357+
get_names_types = json.decode(get_names_types)
358+
359+
local ok, handle = schema.create(get_names_types)
360+
local service_fields = {"string", "int"}
361+
local ok, compiled = schema.compile({handle, service_fields = service_fields})
362+
assert(ok, compiled)
363+
-- get_names
364+
test:is_deeply(schema.get_names(handle),
365+
{"x1","x2.y1","x2.y2","x3","x4.$type$","x4","x5","x6","x7"},
366+
"get_names")
367+
test:is_deeply(schema.get_names(handle, service_fields),
368+
{"$service_field$", "$service_field$", "x1","x2.y1","x2.y2","x3",
369+
"x4.$type$", "x4","x5","x6","x7"},
370+
"get_names sf")
371+
test:is_deeply(compiled.get_names(),
372+
{"$service_field$", "$service_field$", "x1","x2.y1","x2.y2","x3",
373+
"x4.$type$", "x4","x5","x6","x7"},
374+
"compiled.get_names")
375+
376+
-- get_types
377+
test:is_deeply(schema.get_types(handle),
378+
{"string*","string","long","record*","union_type","union_value",
379+
"array*","map","fixed*"},
380+
"get_types")
381+
test:is_deeply(schema.get_types(handle, service_fields),
382+
{"string", "int", "string*","string","long","record*",
383+
"union_type","union_value", "array*","map","fixed*"},
384+
"get_types sf")
385+
test:is_deeply(compiled.get_types(),
386+
{"string", "int", "string*","string","long","record*",
387+
"union_type","union_value", "array*","map","fixed*"},
388+
"compiled.get_types")
389+
336390
test:check()
337391
os.exit(test.planned == test.total and test.failed == 0 and 0 or -1)

0 commit comments

Comments
 (0)