Skip to content

Commit

Permalink
Merge pull request #30577 from JuliaLang/jb/splatnew
Browse files Browse the repository at this point in the history
allow splatting in calls to `new`
  • Loading branch information
JeffBezanson authored Feb 7, 2019
2 parents c6c3d72 + e456a72 commit 2ecc499
Show file tree
Hide file tree
Showing 18 changed files with 200 additions and 67 deletions.
11 changes: 7 additions & 4 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ Julia v1.2 Release Notes
New language features
---------------------

* The `extrema` function now accepts a function argument in the same manner as `minimum` and
`maximum` ([#30323]).
* `hasmethod` can now check for matching keyword argument names ([#30712]).
* `startswith` and `endswith` now accept a `Regex` for the second argument ([#29790]).
* Argument splatting (`x...`) can now be used in calls to the `new` pseudo-function in
constructors ([#30577]).

Multi-threading changes
-----------------------
Expand Down Expand Up @@ -35,6 +33,11 @@ New library functions
Standard library changes
------------------------

* The `extrema` function now accepts a function argument in the same manner as `minimum` and
`maximum` ([#30323]).
* `hasmethod` can now check for matching keyword argument names ([#30712]).
* `startswith` and `endswith` now accept a `Regex` for the second argument ([#29790]).

#### LinearAlgebra

* Added keyword arguments `rtol`, `atol` to `pinv` and `nullspace` ([#29998]).
Expand Down
29 changes: 2 additions & 27 deletions base/boot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -548,33 +548,8 @@ NamedTuple{names}(args::Tuple) where {names} = NamedTuple{names,typeof(args)}(ar

using .Intrinsics: sle_int, add_int

macro generated()
return Expr(:generated)
end

function NamedTuple{names,T}(args::T) where {names, T <: Tuple}
if @generated
N = nfields(names)
flds = Array{Any,1}(undef, N)
i = 1
while sle_int(i, N)
arrayset(false, flds, :(getfield(args, $i)), i)
i = add_int(i, 1)
end
Expr(:new, :(NamedTuple{names,T}), flds...)
else
N = nfields(names)
NT = NamedTuple{names,T}
flds = Array{Any,1}(undef, N)
i = 1
while sle_int(i, N)
arrayset(false, flds, getfield(args, i), i)
i = add_int(i, 1)
end
ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), NT,
ccall(:jl_array_ptr, Ptr{Cvoid}, (Any,), flds), toUInt32(N))::NT
end
end
eval(Core, :(NamedTuple{names,T}(args::T) where {names, T <: Tuple} =
$(Expr(:splatnew, :(NamedTuple{names,T}), :args))))

# constructors for built-in types

Expand Down
3 changes: 3 additions & 0 deletions base/compiler/abstractinterpretation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,9 @@ function abstract_eval(@nospecialize(e), vtypes::VarTable, sv::InferenceState)
t = Const(ccall(:jl_new_structv, Any, (Any, Ptr{Cvoid}, UInt32), t, args, length(args)))
end
end
elseif e.head === :splatnew
t = instanceof_tfunc(abstract_eval(e.args[1], vtypes, sv))[1]
# TODO: improve
elseif e.head === :&
abstract_eval(e.args[1], vtypes, sv)
t = Any
Expand Down
25 changes: 25 additions & 0 deletions base/compiler/ssair/inlining.jl
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,31 @@ function assemble_inline_todo!(ir::IRCode, linetable::Vector{LineInfoNode}, sv::
todo = Any[]
for idx in 1:length(ir.stmts)
stmt = ir.stmts[idx]

if isexpr(stmt, :splatnew)
ty = ir.types[idx]
nf = nfields_tfunc(ty)
if nf isa Const
eargs = stmt.args
tup = eargs[2]
tt = argextype(tup, ir, sv.sptypes)
tnf = nfields_tfunc(tt)
if tnf isa Const && tnf.val <= nf.val
n = tnf.val
new_argexprs = Any[eargs[1]]
for j = 1:n
atype = getfield_tfunc(tt, Const(j))
new_call = Expr(:call, Core.getfield, tup, j)
new_argexpr = insert_node!(ir, idx, atype, new_call)
push!(new_argexprs, new_argexpr)
end
stmt.head = :new
stmt.args = new_argexprs
end
end
continue
end

isexpr(stmt, :call) || continue
eargs = stmt.args
isempty(eargs) && continue
Expand Down
2 changes: 1 addition & 1 deletion base/compiler/ssair/ir.jl
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ function getindex(x::UseRef)
end

function is_relevant_expr(e::Expr)
return e.head in (:call, :invoke, :new, :(=), :(&),
return e.head in (:call, :invoke, :new, :splatnew, :(=), :(&),
:gc_preserve_begin, :gc_preserve_end,
:foreigncall, :isdefined, :copyast,
:undefcheck, :throw_undef_if_not,
Expand Down
20 changes: 10 additions & 10 deletions base/compiler/tfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -329,17 +329,17 @@ function sizeof_tfunc(@nospecialize(x),)
return Int
end
add_tfunc(Core.sizeof, 1, 1, sizeof_tfunc, 0)
add_tfunc(nfields, 1, 1,
function (@nospecialize(x),)
isa(x, Const) && return Const(nfields(x.val))
isa(x, Conditional) && return Const(0)
if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x))
if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x))
return Const(length(x.types))
end
function nfields_tfunc(@nospecialize(x))
isa(x, Const) && return Const(nfields(x.val))
isa(x, Conditional) && return Const(0)
if isa(x, DataType) && !x.abstract && !(x.name === Tuple.name && isvatuple(x))
if !(x.name === _NAMEDTUPLE_NAME && !isconcretetype(x))
return Const(length(x.types))
end
return Int
end, 0)
end
return Int
end
add_tfunc(nfields, 1, 1, nfields_tfunc, 0)
add_tfunc(Core._expr, 1, INT_INF, (@nospecialize args...)->Expr, 100)
function typevar_tfunc(@nospecialize(n), @nospecialize(lb_arg), @nospecialize(ub_arg))
lb = Union{}
Expand Down
5 changes: 3 additions & 2 deletions base/compiler/validation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ const VALID_EXPR_HEADS = IdDict{Any,Any}(
:method => 1:4,
:const => 1:1,
:new => 1:typemax(Int),
:splatnew => 2:2,
:return => 1:1,
:unreachable => 0:0,
:the_exception => 0:0,
Expand Down Expand Up @@ -142,7 +143,7 @@ function validate_code!(errors::Vector{>:InvalidCodeError}, c::CodeInfo, is_top_
head === :inbounds || head === :foreigncall || head === :cfunction ||
head === :const || head === :enter || head === :leave || head == :pop_exception ||
head === :method || head === :global || head === :static_parameter ||
head === :new || head === :thunk || head === :simdloop ||
head === :new || head === :splatnew || head === :thunk || head === :simdloop ||
head === :throw_undef_if_not || head === :unreachable
validate_val!(x)
else
Expand Down Expand Up @@ -224,7 +225,7 @@ end

function is_valid_rvalue(@nospecialize(x))
is_valid_argument(x) && return true
if isa(x, Expr) && x.head in (:new, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
if isa(x, Expr) && x.head in (:new, :splatnew, :the_exception, :isdefined, :call, :invoke, :foreigncall, :cfunction, :gc_preserve_begin, :copyast)
return true
end
return false
Expand Down
7 changes: 7 additions & 0 deletions base/essentials.jl
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,13 @@ convert(::Type{Tuple{Vararg{V}}}, x::Tuple{Vararg{V}}) where {V} = x
convert(T::Type{Tuple{Vararg{V}}}, x::Tuple) where {V} =
(convert(tuple_type_head(T), x[1]), convert(T, tail(x))...)

# used for splatting in `new`
convert_prefix(::Type{Tuple{}}, x::Tuple) = x
convert_prefix(::Type{<:AtLeast1}, x::Tuple{}) = x
convert_prefix(::Type{T}, x::T) where {T<:AtLeast1} = x
convert_prefix(::Type{T}, x::AtLeast1) where {T<:AtLeast1} =
(convert(tuple_type_head(T), x[1]), convert_prefix(tuple_type_tail(T), tail(x))...)

# TODO: the following definitions are equivalent (behaviorally) to the above method
# I think they may be faster / more efficient for inference,
# if we could enable them, but are they?
Expand Down
4 changes: 2 additions & 2 deletions base/show.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1223,8 +1223,8 @@ function show_unquoted(io::IO, ex::Expr, indent::Int, prec::Int)
end

# new expr
elseif head === :new
show_enclosed_list(io, "%new(", args, ", ", ")", indent)
elseif head === :new || head === :splatnew
show_enclosed_list(io, "%$head(", args, ", ", ")", indent)

# other call-like expressions ("A[1,2]", "T{X,Y}", "f.(X,Y)")
elseif haskey(expr_calls, head) && nargs >= 1 # :ref/:curly/:calldecl/:(.)
Expand Down
5 changes: 5 additions & 0 deletions doc/src/devdocs/ast.md
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,11 @@ These symbols appear in the `head` field of [`Expr`](@ref)s in lowered form.
to this, and the type is always inserted by the compiler. This is very much an internal-only
feature, and does no checking. Evaluating arbitrary `new` expressions can easily segfault.

* `splatnew`

Similar to `new`, except field values are passed as a single tuple. Works similarly to
`Base.splat(new)` if `new` were a first-class function, hence the name.

* `return`

Returns its argument as the value of the enclosing function.
Expand Down
2 changes: 2 additions & 0 deletions src/ast.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ jl_sym_t *enter_sym; jl_sym_t *leave_sym;
jl_sym_t *pop_exception_sym;
jl_sym_t *exc_sym; jl_sym_t *error_sym;
jl_sym_t *new_sym; jl_sym_t *using_sym;
jl_sym_t *splatnew_sym;
jl_sym_t *const_sym; jl_sym_t *thunk_sym;
jl_sym_t *abstracttype_sym; jl_sym_t *primtype_sym;
jl_sym_t *structtype_sym; jl_sym_t *foreigncall_sym;
Expand Down Expand Up @@ -325,6 +326,7 @@ void jl_init_frontend(void)
leave_sym = jl_symbol("leave");
pop_exception_sym = jl_symbol("pop_exception");
new_sym = jl_symbol("new");
splatnew_sym = jl_symbol("splatnew");
const_sym = jl_symbol("const");
global_sym = jl_symbol("global");
thunk_sym = jl_symbol("thunk");
Expand Down
21 changes: 21 additions & 0 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,7 @@ static Function *jltls_states_func;

// important functions
static Function *jlnew_func;
static Function *jlsplatnew_func;
static Function *jlthrow_func;
static Function *jlerror_func;
static Function *jltypeerror_func;
Expand Down Expand Up @@ -4069,6 +4070,15 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaval)
// it to the inferred type.
return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type);
}
else if (head == splatnew_sym) {
jl_cgval_t argv[2];
argv[0] = emit_expr(ctx, args[0]);
argv[1] = emit_expr(ctx, args[1]);
Value *typ = boxed(ctx, argv[0]);
Value *tup = boxed(ctx, argv[1]);
Value *val = ctx.builder.CreateCall(prepare_call(jlsplatnew_func), { typ, tup });
return mark_julia_type(ctx, val, true, (jl_value_t*)jl_any_type);
}
else if (head == exc_sym) {
return mark_julia_type(ctx,
ctx.builder.CreateCall(prepare_call(jl_current_exception_func)),
Expand Down Expand Up @@ -6981,6 +6991,17 @@ static void init_julia_llvm_env(Module *m)
jlnew_func->addFnAttr(Thunk);
add_named_global(jlnew_func, &jl_new_structv);

std::vector<Type *> args_2rptrs_(0);
args_2rptrs_.push_back(T_prjlvalue);
args_2rptrs_.push_back(T_prjlvalue);
jlsplatnew_func =
Function::Create(FunctionType::get(T_prjlvalue, args_2rptrs_, false),
Function::ExternalLinkage,
"jl_new_structt", m);
add_return_attr(jlsplatnew_func, Attribute::NonNull);
jlsplatnew_func->addFnAttr(Thunk);
add_named_global(jlsplatnew_func, &jl_new_structt);

std::vector<Type*> args2(0);
args2.push_back(T_pint8);
#ifndef _OS_WINDOWS_
Expand Down
62 changes: 50 additions & 12 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -797,8 +797,24 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...)
return jv;
}

JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
uint32_t na)
static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na)
{
size_t nf = jl_datatype_nfields(type);
for(size_t i=na; i < nf; i++) {
if (jl_field_isptr(type, i)) {
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
}
else {
jl_value_t *ft = jl_field_type(type, i);
if (jl_is_uniontype(ft)) {
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
*psel = 0;
}
}
}
}

JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na)
{
jl_ptls_t ptls = jl_get_ptls_states();
if (type->instance != NULL) {
Expand All @@ -811,7 +827,6 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
}
if (type->layout == NULL)
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
size_t nf = jl_datatype_nfields(type);
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
JL_GC_PUSH1(&jv);
for (size_t i = 0; i < na; i++) {
Expand All @@ -820,18 +835,41 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
jl_type_error("new", ft, args[i]);
jl_set_nth_field(jv, i, args[i]);
}
for(size_t i=na; i < nf; i++) {
if (jl_field_isptr(type, i)) {
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
}
else {
init_struct_tail(type, jv, na);
JL_GC_POP();
return jv;
}

JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup)
{
jl_ptls_t ptls = jl_get_ptls_states();
if (!jl_is_tuple(tup))
jl_type_error("new", (jl_value_t*)jl_tuple_type, tup);
size_t na = jl_nfields(tup);
size_t nf = jl_datatype_nfields(type);
if (na > nf)
jl_too_many_args("new", nf);
if (type->instance != NULL) {
for (size_t i = 0; i < na; i++) {
jl_value_t *ft = jl_field_type(type, i);
if (jl_is_uniontype(ft)) {
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
*psel = 0;
}
jl_value_t *fi = jl_get_nth_field(tup, i);
if (!jl_isa(fi, ft))
jl_type_error("new", ft, fi);
}
return type->instance;
}
if (type->layout == NULL)
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
JL_GC_PUSH1(&jv);
for (size_t i = 0; i < na; i++) {
jl_value_t *ft = jl_field_type(type, i);
jl_value_t *fi = jl_get_nth_field(tup, i);
if (!jl_isa(fi, ft))
jl_type_error("new", ft, fi);
jl_set_nth_field(jv, i, fi);
}
init_struct_tail(type, jv, na);
JL_GC_POP();
return jv;
}
Expand Down
10 changes: 10 additions & 0 deletions src/interpreter.c
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,16 @@ SECT_INTERP static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
JL_GC_POP();
return v;
}
else if (head == splatnew_sym) {
jl_value_t **argv;
JL_GC_PUSHARGS(argv, 2);
argv[0] = eval_value(args[0], s);
argv[1] = eval_value(args[1], s);
assert(jl_is_structtype(argv[0]));
jl_value_t *v = jl_new_structt((jl_datatype_t*)argv[0], argv[1]);
JL_GC_POP();
return v;
}
else if (head == static_parameter_sym) {
ssize_t n = jl_unbox_long(args[0]);
assert(n > 0);
Expand Down
Loading

0 comments on commit 2ecc499

Please sign in to comment.