diff --git a/lib/typeprof/core/ast/sig_decl.rb b/lib/typeprof/core/ast/sig_decl.rb index 886b34b8..c20f50e0 100644 --- a/lib/typeprof/core/ast/sig_decl.rb +++ b/lib/typeprof/core/ast/sig_decl.rb @@ -21,14 +21,21 @@ def initialize(raw_decl, lenv) end.compact # TODO?: param.variance, param.unchecked, param.upper_bound @params = raw_decl.type_params.map {|param| param.name } + @params_default_types = raw_decl.type_params.map do |param| + ty = param.default_type + ty ? AST.create_rbs_type(ty, lenv) : nil + end end - attr_reader :cpath, :members, :params + attr_reader :cpath, :members, :params, :params_default_types - def subnodes = { members: } + def subnodes = { members:, params_default_types: } def attrs = { cpath:, params: } def define0(genv) + @params_default_types.each do |ty| + ty&.define(genv) + end @members.each do |member| member.define(genv) end @@ -46,6 +53,9 @@ def define_copy(genv) def undefine0(genv) mod = genv.resolve_cpath(@cpath) mod.remove_module_decl(genv, self) + @params_default_types.each do |ty| + ty&.undefine(genv) + end @members.each do |member| member.undefine(genv) end diff --git a/lib/typeprof/core/ast/sig_type.rb b/lib/typeprof/core/ast/sig_type.rb index c400d916..aa0c896f 100644 --- a/lib/typeprof/core/ast/sig_type.rb +++ b/lib/typeprof/core/ast/sig_type.rb @@ -33,7 +33,7 @@ def self.typecheck_for_prepended_modules(genv, changes, a_ty, f_mod, f_args, sub if prep_decl.is_a?(AST::SigPrependNode) && prep_mod.type_params prep_ty = genv.get_instance_type(prep_mod, prep_decl.args, changes, {}, a_ty) else - type_params = prep_mod.type_params.map {|ty_param| Source.new() } # TODO: better support + type_params = prep_mod.type_params.map {|(_name, _default_ty)| Source.new() } # TODO: better support prep_ty = Type::Instance.new(genv, prep_mod, type_params) end if prep_ty.mod == f_mod @@ -58,7 +58,7 @@ def self.typecheck_for_included_modules(genv, changes, a_ty, f_mod, f_args, subs if inc_decl.is_a?(AST::SigIncludeNode) && inc_mod.type_params inc_ty = genv.get_instance_type(inc_mod, inc_decl.args, changes, {}, a_ty) else - type_params = inc_mod.type_params.map {|ty_param| Source.new() } # TODO: better support + type_params = inc_mod.type_params.map {|(_name, _default_ty)| Source.new() } # TODO: better support inc_ty = Type::Instance.new(genv, inc_mod, type_params) end if inc_ty.mod == f_mod diff --git a/lib/typeprof/core/env.rb b/lib/typeprof/core/env.rb index 13a4a8e5..4facb095 100644 --- a/lib/typeprof/core/env.rb +++ b/lib/typeprof/core/env.rb @@ -113,16 +113,20 @@ def get_superclass(singleton, mod) def get_instance_type(mod, type_args, changes, base_ty_env, base_ty) ty_env = base_ty_env.dup if base_ty.is_a?(Type::Instance) - base_ty.mod.type_params.zip(base_ty.args) do |param, arg| - ty_env[param] = arg || Source.new + base_ty.mod.type_params.zip(base_ty.args) do |(param, default_ty), arg| + ty_env[param] = arg || (default_ty ? default_ty.covariant_vertex(self, changes, ty_env) : Source.new) end elsif base_ty.is_a?(Type::Singleton) - base_ty.mod.type_params&.each do |param| - ty_env[param] = Source.new + base_ty.mod.type_params&.each do |(param, default_ty)| + ty_env[param] = default_ty ? default_ty.covariant_vertex(self, changes, ty_env) : Source.new end end - args = mod.type_params.zip(type_args).map do |param, arg| - arg && changes ? arg.covariant_vertex(self, changes, ty_env) : Source.new + args = mod.type_params.zip(type_args).map do |(param, default_ty), arg| + if changes + (arg || default_ty)&.covariant_vertex(self, changes, ty_env) || Source.new + else + Source.new + end end Type::Instance.new(self, mod, args) end @@ -380,7 +384,7 @@ def get_self(genv) case @scope_level when :instance mod = genv.resolve_cpath(@cpath || []) - type_params = mod.type_params.map {|ty_param| Source.new() } # TODO: better support + type_params = mod.type_params.map {|(_name, _default_ty)| Source.new() } # TODO: better support ty = Type::Instance.new(genv, mod, type_params) Source.new(ty) when :class diff --git a/lib/typeprof/core/env/module_entity.rb b/lib/typeprof/core/env/module_entity.rb index 005a83ce..9ca3d39c 100644 --- a/lib/typeprof/core/env/module_entity.rb +++ b/lib/typeprof/core/env/module_entity.rb @@ -25,7 +25,7 @@ def initialize(cpath, outer_module = self) # class Foo[X, Y, Z] < Bar[A, B, C] @superclass_type_args = nil # A, B, C - @type_params = [] # X, Y, Z + @type_params = {} # X, Y, Z @consts = {} @methods = { true => {}, false => {} } @@ -115,9 +115,10 @@ def add_module_decl(genv, decl) @module_decls << decl if @type_params - update_type_params if @type_params != decl.params + update_type_params else - @type_params = decl.params + @type_params = {} + decl.params.zip(decl.params_default_types) {|name, default_type| @type_params[name] = default_type } end if decl.is_a?(AST::SigClassNode) && !@superclass_type_args @@ -133,7 +134,7 @@ def remove_module_decl(genv, decl) @outer_module.get_const(get_cname).remove_decl(decl) @module_decls.delete(decl) || raise - update_type_params if @type_params == decl.params + update_type_params if decl.is_a?(AST::SigClassNode) && @superclass_type_args == decl.superclass_args @superclass_type_args = nil @module_decls.each do |decl| @@ -152,13 +153,12 @@ def update_type_params @module_decls.each do |decl| params = decl.params next unless params - if @type_params - @type_params = params if (@type_params <=> params) > 0 - else - @type_params = params + if !@type_params || @type_params.size < params.size + @type_params = {} + params.zip(decl.params_default_types) {|name, default_type| @type_params[name] = default_type } end end - @type_params ||= [] + @type_params ||= {} # TODO: report an error if there are multiple inconsistent declarations end diff --git a/lib/typeprof/core/graph/box.rb b/lib/typeprof/core/graph/box.rb index 103bb0d9..9c62aeb6 100644 --- a/lib/typeprof/core/graph/box.rb +++ b/lib/typeprof/core/graph/box.rb @@ -87,7 +87,7 @@ def run0(genv, changes) mod = genv.resolve_cpath(@node.cpath) if mod.type_params && !mod.type_params.empty? # Create a substitution map where each type parameter maps to a type variable vertex - subst = mod.type_params.to_h do |param| + subst = mod.type_params.to_h do |param, _default_ty| type_var_vtx = Vertex.new(@node) [param, type_var_vtx] end @@ -185,8 +185,8 @@ def match_arguments?(genv, changes, param_map, a_args, method_type) def resolve_overload(changes, genv, method_type, node, param_map, a_args, ret, force) param_map0 = param_map.dup if method_type.type_params - method_type.type_params.zip(yield(method_type)) do |var, vtx| - param_map0[var] = vtx + method_type.type_params.zip(yield(method_type)) do |(var, _default_ty), vtx| + param_map0[var] = vtx # TODO: default_ty? end end @@ -443,17 +443,17 @@ def run0(genv, changes) ty = Type::Singleton.new(genv, mod) param_map0 = Type.default_param_map(genv, ty) else - type_params = mod.type_params.map {|ty_param| Source.new() } # TODO: better support + type_params = mod.type_params.map {|(_name, _default_ty)| Source.new() } # TODO: better support ty = Type::Instance.new(genv, mod, type_params) param_map0 = Type.default_param_map(genv, ty) if ty.is_a?(Type::Instance) - ty.mod.type_params.zip(ty.args) do |param, arg| - param_map0[param] = arg + ty.mod.type_params.zip(ty.args) do |(name, _default_ty), arg| + param_map0[name] = arg end end end - method_type.type_params.each do |param| - param_map0[param] = Source.new() + method_type.type_params.each do |name, _default_ty| + param_map0[name] = Source.new() end positional_args = [] @@ -739,12 +739,12 @@ def run0(genv, changes) # TODO: add_depended_method_entity for types used to resolve overloads ty_env = Type.default_param_map(genv, orig_ty) if ty.is_a?(Type::Instance) - ty.mod.type_params.zip(ty.args) do |param, arg| - ty_env[param] = arg + ty.mod.type_params.zip(ty.args) do |(param, default_ty), arg| + ty_env[param] = arg || (default_ty ? default_ty.covariant_vertex(genv, changes, ty_env) : Source.new) end end mdecl.resolve_overloads(changes, genv, @node, ty_env, @a_args, @ret) do |method_type| - @generics[method_type] ||= method_type.type_params.map {|var| Vertex.new(@node) } + @generics[method_type] ||= method_type.type_params.map { Vertex.new(@node) } end end elsif !me.defs.empty? @@ -846,7 +846,7 @@ def resolve_prepended_modules(genv, changes, base_ty_env, ty, mid, &blk) if prep_decl.is_a?(AST::SigPrependNode) && prep_mod.type_params prep_ty = genv.get_instance_type(prep_mod, prep_decl.args, changes, base_ty_env, ty) else - type_params = prep_mod.type_params.map {|ty_param| Source.new() } # TODO: better support + type_params = prep_mod.type_params.map { Source.new() } # TODO: better support prep_ty = Type::Instance.new(genv, prep_mod, type_params) end @@ -901,7 +901,7 @@ def resolve_included_modules(genv, changes, base_ty_env, ty, mid, &blk) if inc_decl.is_a?(AST::SigIncludeNode) && inc_mod.type_params inc_ty = genv.get_instance_type(inc_mod, inc_decl.args, changes, base_ty_env, ty) else - type_params = inc_mod.type_params.map {|ty_param| Source.new() } # TODO: better support + type_params = inc_mod.type_params.map { Source.new() } # TODO: better support inc_ty = Type::Instance.new(genv, inc_mod, type_params) end diff --git a/lib/typeprof/core/type.rb b/lib/typeprof/core/type.rb index 960840d3..cda5377b 100644 --- a/lib/typeprof/core/type.rb +++ b/lib/typeprof/core/type.rb @@ -60,7 +60,7 @@ def show def get_instance_type(genv) params = @mod.type_params - Instance.new(genv, @mod, params ? params.map { Source.new } : []) + Instance.new(genv, @mod, params ? params.map { Source.new } : []) # TODO: respect param_default_types end end diff --git a/scenario/rbs/param-default-type.rb b/scenario/rbs/param-default-type.rb new file mode 100644 index 00000000..d2af2013 --- /dev/null +++ b/scenario/rbs/param-default-type.rb @@ -0,0 +1,27 @@ +## update: test.rbs +class Foo[X, Y = Integer] + def get_x: -> X + def get_y: -> Y +end + +class Object + def create_foo_str_int: -> Foo[String] + def create_foo_str_str: -> Foo[String, String] +end + +## update: test.rb +def check1 + x = create_foo_str_int + [x.get_x, x.get_y] +end + +def check2 + x = create_foo_str_str + [x.get_x, x.get_y] +end + +## assert +class Object + def check1: -> [String, Integer] + def check2: -> [String, String] +end