From 08c66ef2e17f685e2a2b3195909ac28816feb19a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 4 Nov 2021 19:07:18 -0700 Subject: [PATCH 01/32] Sync to upstream/release/502 Changes: - Support for time tracing for analysis/compiler (not currently exposed through CLI) - Support for type pack arguments in type aliases (#83) - Basic support for require(path) in luau-analyze - Add a lint warning for table.move with 0 index as part of TableOperation lint - Remove last STL dependency from Luau.VM - Minor VS2022 performance tuning Co-authored-by: Rodactor --- .gitignore | 7 + Analysis/include/Luau/BuiltinDefinitions.h | 3 +- Analysis/include/Luau/Error.h | 1 + Analysis/include/Luau/FileResolver.h | 58 +- Analysis/include/Luau/Module.h | 10 +- Analysis/include/Luau/ModuleResolver.h | 6 - Analysis/include/Luau/RequireTracer.h | 5 +- Analysis/include/Luau/Scope.h | 67 ++ Analysis/include/Luau/Substitution.h | 14 +- Analysis/include/Luau/TypeInfer.h | 56 +- Analysis/include/Luau/TypePack.h | 3 +- Analysis/include/Luau/TypeVar.h | 22 +- Analysis/include/Luau/Unifier.h | 18 +- Analysis/src/AstQuery.cpp | 17 +- Analysis/src/Autocomplete.cpp | 49 +- Analysis/src/BuiltinDefinitions.cpp | 95 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 21 - Analysis/src/Error.cpp | 108 ++- Analysis/src/Frontend.cpp | 51 +- Analysis/src/IostreamHelpers.cpp | 18 +- Analysis/src/JsonEncoder.cpp | 32 +- Analysis/src/Linter.cpp | 19 +- Analysis/src/Module.cpp | 44 +- Analysis/src/RequireTracer.cpp | 215 ++++- Analysis/src/Scope.cpp | 123 +++ Analysis/src/Substitution.cpp | 34 +- Analysis/src/ToString.cpp | 131 ++- Analysis/src/Transpiler.cpp | 38 +- Analysis/src/TypeAttach.cpp | 124 ++- Analysis/src/TypeInfer.cpp | 567 ++++++------ Analysis/src/TypePack.cpp | 13 + Analysis/src/TypeUtils.cpp | 18 +- Analysis/src/TypeVar.cpp | 36 +- Analysis/src/Unifier.cpp | 500 +++++++++-- Ast/include/Luau/Ast.h | 33 +- Ast/include/Luau/DenseHash.h | 5 +- Ast/include/Luau/Parser.h | 8 +- Ast/include/Luau/TimeTrace.h | 223 +++++ Ast/src/Ast.cpp | 37 +- Ast/src/Parser.cpp | 145 ++- Ast/src/TimeTrace.cpp | 248 ++++++ CLI/Analyze.cpp | 18 +- Compiler/src/Compiler.cpp | 10 + Sources.cmake | 5 + VM/src/ldo.cpp | 10 +- VM/src/lgc.cpp | 191 +++- VM/src/ltablib.cpp | 22 - VM/src/lvmexecute.cpp | 12 +- VM/src/lvmload.cpp | 34 +- bench/tests/deltablue.lua | 934 -------------------- tests/Autocomplete.test.cpp | 853 +++++++++--------- tests/Fixture.cpp | 49 + tests/Fixture.h | 2 + tests/Frontend.test.cpp | 31 +- tests/Linter.test.cpp | 9 +- tests/Module.test.cpp | 1 + tests/NonstrictMode.test.cpp | 1 + tests/Parser.test.cpp | 15 + tests/RequireTracer.test.cpp | 68 +- tests/ToString.test.cpp | 3 +- tests/TypeInfer.aliases.test.cpp | 557 ++++++++++++ tests/TypeInfer.provisional.test.cpp | 36 +- tests/TypeInfer.refinements.test.cpp | 72 +- tests/TypeInfer.tables.test.cpp | 230 +++-- tests/TypeInfer.test.cpp | 563 +----------- tests/TypeInfer.tryUnify.test.cpp | 1 + tests/TypeInfer.typePacks.cpp | 366 ++++++++ tests/TypeInfer.unionTypes.test.cpp | 16 +- tests/TypeVar.test.cpp | 1 + tools/tracegraph.py | 95 ++ 70 files changed, 4475 insertions(+), 2952 deletions(-) create mode 100644 .gitignore create mode 100644 Analysis/include/Luau/Scope.h create mode 100644 Analysis/src/Scope.cpp create mode 100644 Ast/include/Luau/TimeTrace.h create mode 100644 Ast/src/TimeTrace.cpp delete mode 100644 bench/tests/deltablue.lua create mode 100644 tests/TypeInfer.aliases.test.cpp create mode 100644 tools/tracegraph.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..0b2422ced --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +^build/ +^coverage/ +^fuzz/luau.pb.* +^crash-* +^default.prof* +^fuzz-* +^luau$ diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 8f17fff65..57a1907a5 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -1,7 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "TypeInfer.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" namespace Luau { diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 946bc9288..ac6f13e96 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -120,6 +120,7 @@ struct IncorrectGenericParameterCount Name name; TypeFun typeFun; size_t actualParameters; + size_t actualPackParameters; bool operator==(const IncorrectGenericParameterCount& rhs) const; }; diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 71f9464b8..a05ec5e91 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -25,51 +25,39 @@ struct SourceCode Type type; }; +struct ModuleInfo +{ + ModuleName name; + bool optional = false; +}; + struct FileResolver { virtual ~FileResolver() {} - /** Fetch the source code associated with the provided ModuleName. - * - * FIXME: This requires a string copy! - * - * @returns The actual Lua code on success. - * @returns std::nullopt if no such file exists. When this occurs, type inference will report an UnknownRequire error. - */ virtual std::optional readSource(const ModuleName& name) = 0; - /** Does the module exist? - * - * Saves a string copy over reading the source and throwing it away. - */ - virtual bool moduleExists(const ModuleName& name) const = 0; - - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - - /** Given a valid module name and a string of arbitrary data, figure out the concatenation. - */ - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - - /** Goes "up" a level in the hierarchy that the ModuleName represents. - * - * For instances, this is analogous to someInstance.Parent; for paths, this is equivalent to removing the last - * element of the path. Other ModuleName representations may have other ways of doing this. - * - * @returns The parent ModuleName, if one exists. - * @returns std::nullopt if there is no parent for this module name. - */ - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; + virtual std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) + { + return std::nullopt; + } - virtual std::optional getHumanReadableModuleName_(const ModuleName& name) const + virtual std::string getHumanReadableModuleName(const ModuleName& name) const { return name; } - virtual std::optional getEnvironmentForModule(const ModuleName& name) const = 0; + virtual std::optional getEnvironmentForModule(const ModuleName& name) const + { + return std::nullopt; + } - /** LanguageService only: - * std::optional fromInstance(Instance* inst) - */ + // DEPRECATED APIS + // These are going to be removed with LuauNewRequireTracer + virtual bool moduleExists(const ModuleName& name) const = 0; + virtual std::optional fromAstFragment(AstExpr* expr) const = 0; + virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; + virtual std::optional getParentModuleName(const ModuleName& name) const = 0; }; struct NullFileResolver : FileResolver @@ -94,10 +82,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - std::optional getEnvironmentForModule(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 413b68f40..d08448351 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -90,10 +90,12 @@ struct Module TypeArena internalTypes; std::vector> scopes; // never empty - std::unordered_map astTypes; - std::unordered_map astExpectedTypes; - std::unordered_map astOriginalCallTypes; - std::unordered_map astOverloadResolvedTypes; + + DenseHashMap astTypes{nullptr}; + DenseHashMap astExpectedTypes{nullptr}; + DenseHashMap astOriginalCallTypes{nullptr}; + DenseHashMap astOverloadResolvedTypes{nullptr}; + std::unordered_map declaredGlobals; ErrorVec errors; Mode mode; diff --git a/Analysis/include/Luau/ModuleResolver.h b/Analysis/include/Luau/ModuleResolver.h index a394a21b4..d892ccd7f 100644 --- a/Analysis/include/Luau/ModuleResolver.h +++ b/Analysis/include/Luau/ModuleResolver.h @@ -15,12 +15,6 @@ struct Module; using ModulePtr = std::shared_ptr; -struct ModuleInfo -{ - ModuleName name; - bool optional = false; -}; - struct ModuleResolver { virtual ~ModuleResolver() {} diff --git a/Analysis/include/Luau/RequireTracer.h b/Analysis/include/Luau/RequireTracer.h index e9778876c..c25545f57 100644 --- a/Analysis/include/Luau/RequireTracer.h +++ b/Analysis/include/Luau/RequireTracer.h @@ -17,12 +17,11 @@ struct AstLocal; struct RequireTraceResult { - DenseHashMap exprs{0}; - DenseHashMap optional{0}; + DenseHashMap exprs{nullptr}; std::vector> requires; }; -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName); +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName); } // namespace Luau diff --git a/Analysis/include/Luau/Scope.h b/Analysis/include/Luau/Scope.h new file mode 100644 index 000000000..45338409f --- /dev/null +++ b/Analysis/include/Luau/Scope.h @@ -0,0 +1,67 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Location.h" +#include "Luau/TypeVar.h" + +#include +#include +#include + +namespace Luau +{ + +struct Scope; + +using ScopePtr = std::shared_ptr; + +struct Binding +{ + TypeId typeId; + Location location; + bool deprecated = false; + std::string deprecatedSuggestion; + std::optional documentationSymbol; +}; + +struct Scope +{ + explicit Scope(TypePackId returnType); // root scope + explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. + + const ScopePtr parent; // null for the root + std::unordered_map bindings; + TypePackId returnType; + bool breakOk = false; + std::optional varargPack; + + TypeLevel level; + + std::unordered_map exportedTypeBindings; + std::unordered_map privateTypeBindings; + std::unordered_map typeAliasLocations; + + std::unordered_map> importedTypeBindings; + + std::optional lookup(const Symbol& name); + + std::optional lookupType(const Name& name); + std::optional lookupImportedType(const Name& moduleAlias, const Name& name); + + std::unordered_map privateTypePackBindings; + std::optional lookupPack(const Name& name); + + // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) + std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); + + RefinementMap refinements; + + // For mutually recursive type aliases, it's important that + // they use the same types for the same names. + // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` + // we need that the generic type `T` in both cases is the same, so we use a cache. + std::unordered_map typeAliasTypeParameters; + std::unordered_map typeAliasTypePackParameters; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 6ac868f76..80a14e8fb 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -52,8 +52,6 @@ // `T`, and the type of `f` are in the same SCC, which is why `f` gets // replaced. -LUAU_FASTFLAG(DebugLuauTrackOwningArena) - namespace Luau { @@ -188,20 +186,12 @@ struct Substitution : FindDirty template TypeId addType(const T& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(tv); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(tv); } template TypePackId addTypePack(const T& tp) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(tp); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(TypePackVar{tp}); } }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index ec2a1a26f..d701eb248 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -86,7 +86,10 @@ struct ApplyTypeFunction : Substitution { TypeLevel level; bool encounteredForwardedType; - std::unordered_map arguments; + std::unordered_map typeArguments; + std::unordered_map typePackArguments; + bool ignoreChildren(TypeId ty) override; + bool ignoreChildren(TypePackId tp) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; TypeId clean(TypeId ty) override; @@ -328,7 +331,8 @@ struct TypeChecker TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); - TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location); + TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( @@ -398,54 +402,6 @@ struct TypeChecker int recursionCount = 0; }; -struct Binding -{ - TypeId typeId; - Location location; - bool deprecated = false; - std::string deprecatedSuggestion; - std::optional documentationSymbol; -}; - -struct Scope -{ - explicit Scope(TypePackId returnType); // root scope - explicit Scope(const ScopePtr& parent, int subLevel = 0); // child scope. Parent must not be nullptr. - - const ScopePtr parent; // null for the root - std::unordered_map bindings; - TypePackId returnType; - bool breakOk = false; - std::optional varargPack; - - TypeLevel level; - - std::unordered_map exportedTypeBindings; - std::unordered_map privateTypeBindings; - std::unordered_map typeAliasLocations; - - std::unordered_map> importedTypeBindings; - - std::optional lookup(const Symbol& name); - - std::optional lookupType(const Name& name); - std::optional lookupImportedType(const Name& moduleAlias, const Name& name); - - std::unordered_map privateTypePackBindings; - std::optional lookupPack(const Name& name); - - // WARNING: This function linearly scans for a string key of equal value! It is thus O(n**2) - std::optional linearSearchForBinding(const std::string& name, bool traverseScopeChain = true); - - RefinementMap refinements; - - // For mutually recursive type aliases, it's important that - // they use the same types for the same names. - // For instance, in `type Tree { data: T, children: Forest } type Forest = {Tree}` - // we need that the generic type `T` in both cases is the same, so we use a cache. - std::unordered_map typeAliasParameters; -}; - // Unit test hook void setPrintLine(void (*pl)(const std::string& s)); void resetPrintLine(); diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 0d0adce7e..d987d46ca 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -117,7 +117,8 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); -size_t size(const TypePackId tp); +size_t size(TypePackId tp); +bool finite(TypePackId tp); size_t size(const TypePack& tp); std::optional first(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 90a28b20d..d4e4e4913 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -228,6 +228,7 @@ struct TableTypeVar std::map methodDefinitionLocations; std::vector instantiatedTypeParams; + std::vector instantiatedTypePackParams; ModuleName definitionModuleName; std::optional boundTo; @@ -284,8 +285,9 @@ struct ClassTypeVar struct TypeFun { - /// These should all be generic + // These should all be generic std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -293,6 +295,20 @@ struct TypeFun * You must first use TypeChecker::instantiateTypeFun to turn it into a real type. */ TypeId type; + + TypeFun() = default; + TypeFun(std::vector typeParams, TypeId type) + : typeParams(std::move(typeParams)) + , type(type) + { + } + + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + : typeParams(std::move(typeParams)) + , typePackParams(std::move(typePackParams)) + , type(type) + { + } }; // Anything! All static checking is off. @@ -524,8 +540,4 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); -// TEMP: Clip this prototype with FFlag::LuauStringMetatable -std::optional> magicFunctionFormat( - struct TypeChecker& typechecker, const std::shared_ptr& scope, const AstExprCall& expr, ExprResult exprResult); - } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 0ddc3cc0b..522914b2f 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -36,12 +36,17 @@ struct Unifier Variance variance = Covariant; CountMismatch::Context ctx = CountMismatch::Arg; - std::shared_ptr counters; + UnifierCounters* counters; + UnifierCounters countersData; + + std::shared_ptr counters_DEPRECATED; + InternalErrorReporter* iceHandler; Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters = nullptr); + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId superTy, TypeId subTy); @@ -58,11 +63,13 @@ struct Unifier void tryUnifyPrimitives(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); + void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void tryUnifyFreeTable(TypeId free, TypeId other); void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -80,9 +87,9 @@ struct Unifier public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack); void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack); + void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); @@ -93,6 +100,9 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; }; } // namespace Luau diff --git a/Analysis/src/AstQuery.cpp b/Analysis/src/AstQuery.cpp index d3de1754a..0aed34c0a 100644 --- a/Analysis/src/AstQuery.cpp +++ b/Analysis/src/AstQuery.cpp @@ -2,6 +2,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/ToString.h" @@ -143,8 +144,8 @@ std::optional findTypeAtPosition(const Module& module, const SourceModul { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - return it->second; + if (auto it = module.astTypes.find(expr)) + return *it; } return std::nullopt; @@ -154,8 +155,8 @@ std::optional findExpectedTypeAtPosition(const Module& module, const Sou { if (auto expr = findExprAtPosition(sourceModule, pos)) { - if (auto it = module.astExpectedTypes.find(expr); it != module.astExpectedTypes.end()) - return it->second; + if (auto it = module.astExpectedTypes.find(expr)) + return *it; } return std::nullopt; @@ -322,9 +323,9 @@ std::optional getDocumentationSymbolAtPosition(const Source TypeId matchingOverload = nullptr; if (parentExpr && parentExpr->is()) { - if (auto it = module.astOverloadResolvedTypes.find(parentExpr); it != module.astOverloadResolvedTypes.end()) + if (auto it = module.astOverloadResolvedTypes.find(parentExpr)) { - matchingOverload = it->second; + matchingOverload = *it; } } @@ -345,9 +346,9 @@ std::optional getDocumentationSymbolAtPosition(const Source { if (AstExprIndexName* indexName = targetExpr->as()) { - if (auto it = module.astTypes.find(indexName->expr); it != module.astTypes.end()) + if (auto it = module.astTypes.find(indexName->expr)) { - TypeId parentTy = follow(it->second); + TypeId parentTy = follow(*it); if (const TableTypeVar* ttv = get(parentTy)) { if (auto propIt = ttv->props.find(indexName->index.value); propIt != ttv->props.end()) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index dce92a0c9..235abf36f 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -210,10 +210,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ return TypeCorrectKind::None; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return TypeCorrectKind::None; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (canUnify(expectedType, ty)) return TypeCorrectKind::Correct; @@ -682,10 +682,10 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n return std::nullopt; auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return std::nullopt; - TypeId expectedType = follow(it->second); + TypeId expectedType = follow(*it); if (const FunctionTypeVar* ftv = get(expectedType)) return true; @@ -784,9 +784,9 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (AstExprCall* exprCall = expr->as()) { - if (auto it = module.astTypes.find(exprCall->func); it != module.astTypes.end()) + if (auto it = module.astTypes.find(exprCall->func)) { - if (const FunctionTypeVar* ftv = get(follow(it->second))) + if (const FunctionTypeVar* ftv = get(follow(*it))) { if (auto ty = tryGetTypePackTypeAt(ftv->retType, tailPos)) inferredType = *ty; @@ -798,8 +798,8 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi if (tailPos != 0) break; - if (auto it = module.astTypes.find(expr); it != module.astTypes.end()) - inferredType = it->second; + if (auto it = module.astTypes.find(expr)) + inferredType = *it; } if (inferredType) @@ -815,10 +815,10 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi auto tryGetExpectedFunctionType = [](const Module& module, AstExpr* expr) -> const FunctionTypeVar* { auto it = module.astExpectedTypes.find(expr); - if (it == module.astExpectedTypes.end()) + if (!it) return nullptr; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); if (const FunctionTypeVar* ftv = get(ty)) return ftv; @@ -1129,9 +1129,8 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (node->is()) { - auto it = module.astTypes.find(node->asExpr()); - if (it != module.astTypes.end()) - autocompleteProps(module, typeArena, it->second, PropIndexType::Point, ancestry, result); + if (auto it = module.astTypes.find(node->asExpr())) + autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) return; @@ -1203,13 +1202,13 @@ static std::optional getMethodContainingClass(const ModuleP return std::nullopt; } - auto parentIter = module->astTypes.find(parentExpr); - if (parentIter == module->astTypes.end()) + auto parentIt = module->astTypes.find(parentExpr); + if (!parentIt) { return std::nullopt; } - Luau::TypeId parentType = Luau::follow(parentIter->second); + Luau::TypeId parentType = Luau::follow(*parentIt); if (auto parentClass = Luau::get(parentType)) { @@ -1250,8 +1249,8 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - auto iter = module->astTypes.find(candidate->func); - if (iter == module->astTypes.end()) + auto it = module->astTypes.find(candidate->func); + if (!it) { return std::nullopt; } @@ -1267,7 +1266,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; }; - auto followedId = Luau::follow(iter->second); + auto followedId = Luau::follow(*it); if (auto functionType = Luau::get(followedId)) { return performCallback(functionType); @@ -1316,10 +1315,10 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (auto indexName = node->as()) { auto it = module->astTypes.find(indexName->expr); - if (it == module->astTypes.end()) + if (!it) return {}; - TypeId ty = follow(it->second); + TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; if (isString(ty)) @@ -1447,9 +1446,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M // If item doesn't have a key, maybe the value is actually the key if (key ? key == node : node->is() && value == node) { - if (auto it = module->astExpectedTypes.find(exprTable); it != module->astExpectedTypes.end()) + if (auto it = module->astExpectedTypes.find(exprTable)) { - auto result = autocompleteProps(*module, typeArena, it->second, PropIndexType::Key, finder.ancestry); + auto result = autocompleteProps(*module, typeArena, *it, PropIndexType::Key, finder.ancestry); // Remove keys that are already completed for (const auto& item : exprTable->items) @@ -1485,9 +1484,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (auto idxExpr = finder.ancestry.at(finder.ancestry.size() - 2)->as()) { - if (auto it = module->astTypes.find(idxExpr->expr); it != module->astTypes.end()) + if (auto it = module->astTypes.find(idxExpr->expr)) { - return {autocompleteProps(*module, typeArena, follow(it->second), PropIndexType::Point, finder.ancestry), finder.ancestry}; + return {autocompleteProps(*module, typeArena, follow(*it), PropIndexType::Point, finder.ancestry), finder.ancestry}; } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 68ad5ac9f..3b0c21638 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -11,7 +11,7 @@ LUAU_FASTFLAG(LuauParseGenericFunctions) LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) +LUAU_FASTFLAG(LuauNewRequireTrace) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -218,7 +218,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypePackId anyTypePack = typeChecker.anyTypePack; TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); - TypePackId stringVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{stringType}}); TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); TypeId listOfAtLeastOneNumberToNumberType = arena.addType(FunctionTypeVar{ @@ -255,85 +254,18 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); - if (FFlag::LuauStringMetatable) - { - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); - LUAU_ASSERT(stringMetatableTy); - const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); - LUAU_ASSERT(stringMetatableTable); + std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + LUAU_ASSERT(stringMetatableTy); + const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); + LUAU_ASSERT(stringMetatableTable); - auto it = stringMetatableTable->props.find("__index"); - LUAU_ASSERT(it != stringMetatableTable->props.end()); + auto it = stringMetatableTable->props.find("__index"); + LUAU_ASSERT(it != stringMetatableTable->props.end()); - TypeId stringLib = it->second.type; - addGlobalBinding(typeChecker, "string", stringLib, "@luau"); - } + addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - if (!FFlag::LuauStringMetatable) - { - TypeId stringLibTy = getGlobalBinding(typeChecker, "string"); - TableTypeVar* stringLib = getMutable(stringLibTy); - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - stringLib->props["gsub"] = makeProperty(gsubFunc, "@luau/global/string.gsub"); - } - } - else + if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions) { - if (!FFlag::LuauStringMetatable) - { - TypeId stringToStringType = makeFunction(arena, std::nullopt, {stringType}, {stringType}); - - TypeId gmatchFunc = makeFunction(arena, stringType, {stringType}, {arena.addType(FunctionTypeVar{emptyPack, stringVariadicList})}); - - TypeId replArgType = makeUnion( - arena, {stringType, - arena.addType(TableTypeVar({}, TableIndexer(stringType, stringType), typeChecker.globalScope->level, TableState::Generic)), - makeFunction(arena, std::nullopt, {stringType}, {stringType})}); - TypeId gsubFunc = makeFunction(arena, stringType, {stringType, replArgType, optionalNumber}, {stringType, numberType}); - - TypeId formatFn = arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack}); - - TableTypeVar::Props stringLib = { - // FIXME string.byte "can" return a pack of numbers, but only if 2nd or 3rd arguments were supplied - {"byte", {makeFunction(arena, stringType, {optionalNumber, optionalNumber}, {optionalNumber})}}, - // FIXME char takes a variadic pack of numbers - {"char", {makeFunction(arena, std::nullopt, {numberType, optionalNumber, optionalNumber, optionalNumber}, {stringType})}}, - {"find", {makeFunction(arena, stringType, {stringType, optionalNumber, optionalBoolean}, {optionalNumber, optionalNumber})}}, - {"format", {formatFn}}, // FIXME - {"gmatch", {gmatchFunc}}, - {"gsub", {gsubFunc}}, - {"len", {makeFunction(arena, stringType, {}, {numberType})}}, - {"lower", {stringToStringType}}, - {"match", {makeFunction(arena, stringType, {stringType, optionalNumber}, {optionalString})}}, - {"rep", {makeFunction(arena, stringType, {numberType}, {stringType})}}, - {"reverse", {stringToStringType}}, - {"sub", {makeFunction(arena, stringType, {numberType, optionalNumber}, {stringType})}}, - {"upper", {stringToStringType}}, - {"split", {makeFunction(arena, stringType, {stringType, optionalString}, - {arena.addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, typeChecker.globalScope->level})})}}, - {"pack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType}, anyTypePack}), - oneStringPack, - })}}, - {"packsize", {makeFunction(arena, stringType, {}, {numberType})}}, - {"unpack", {arena.addType(FunctionTypeVar{ - arena.addTypePack(TypePack{{stringType, stringType, optionalNumber}}), - anyTypePack, - })}}, - }; - - assignPropDocumentationSymbols(stringLib, "@luau/global/string"); - addGlobalBinding(typeChecker, "string", - arena.addType(TableTypeVar{stringLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - } - TableTypeVar::Props debugLib{ {"info", {makeIntersection(arena, { @@ -601,9 +533,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); - auto stringLib = getMutable(getGlobalBinding(typeChecker, "string")); - attachMagicFunction(stringLib->props["format"].type, magicFunctionFormat); - attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } @@ -791,11 +720,11 @@ static std::optional> magicFunctionRequire( return std::nullopt; } - AstExpr* require = expr.args.data[0]; - - if (!checkRequirePath(typechecker, require)) + if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; + const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0]; + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 61a63f067..1e91561a6 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -206,27 +206,6 @@ std::string getBuiltinDefinitionSource() graphemes: (string, number?, number?) -> (() -> (number, number)), } - declare string: { - byte: (string, number?, number?) -> ...number, - char: (number, ...number) -> string, - find: (string, string, number?, boolean?) -> (number?, number?), - -- `string.format` has a magic function attached that will provide more type information for literal format strings. - format: (string, A...) -> string, - gmatch: (string, string) -> () -> (...string), - -- gsub is defined in C++ because we don't have syntax for describing a generic table. - len: (string) -> number, - lower: (string) -> string, - match: (string, string, number?) -> string?, - rep: (string, number) -> string, - reverse: (string) -> string, - sub: (string, number, number?) -> string, - upper: (string) -> string, - split: (string, string, string?) -> {string}, - pack: (string, A...) -> string, - packsize: (string) -> number, - unpack: (string, string, number?) -> R..., - } - -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V )"; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 680bcf3f0..92fbffc80 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,9 +7,9 @@ #include -LUAU_FASTFLAG(LuauFasterStringifier) +LUAU_FASTFLAG(LuauTypeAliasPacks) -static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) +static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) { std::string s = "expects " + std::to_string(expectedCount) + " "; @@ -41,6 +41,52 @@ static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCo return s; } +static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) +{ + std::string s; + + if (FFlag::LuauTypeAliasPacks) + { + s = "expects "; + + if (isVariadic) + s += "at least "; + + s += std::to_string(expectedCount) + " "; + } + else + { + s = "expects " + std::to_string(expectedCount) + " "; + } + + if (argPrefix) + s += std::string(argPrefix) + " "; + + s += "argument"; + if (expectedCount != 1) + s += "s"; + + s += ", but "; + + if (actualCount == 0) + { + s += "none"; + } + else + { + if (actualCount < expectedCount) + s += "only "; + + s += std::to_string(actualCount); + } + + s += (actualCount == 1) ? " is" : " are"; + + s += " specified"; + + return s; +} + namespace Luau { @@ -128,7 +174,10 @@ struct ErrorConverter else return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + if (FFlag::LuauTypeAliasPacks) + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + else + return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual); } LUAU_ASSERT(!"Unknown context"); @@ -160,13 +209,16 @@ struct ErrorConverter std::string operator()(const Luau::UnknownRequire& e) const { - return "Unknown require: " + e.modulePath; + if (e.modulePath.empty()) + return "Unknown require: unsupported path"; + else + return "Unknown require: " + e.modulePath; } std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty()) + if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty())) { name += "<"; bool first = true; @@ -179,10 +231,37 @@ struct ErrorConverter name += toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : e.typeFun.typePackParams) + { + if (first) + first = false; + else + name += ", "; + + name += toString(t); + } + } + name += ">"; } - return "Generic type '" + name + "' " + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + if (FFlag::LuauTypeAliasPacks) + { + if (e.typeFun.typeParams.size() != e.actualParameters) + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); + } + else + { + return "Generic type '" + name + "' " + + wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); + } } std::string operator()(const Luau::SyntaxError& e) const @@ -471,9 +550,26 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; + if (FFlag::LuauTypeAliasPacks) + { + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; + } + for (size_t i = 0; i < typeFun.typeParams.size(); ++i) + { if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) return false; + } + + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) + { + if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + return false; + } + } return true; } diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 4d385ec19..b2529840b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Frontend.h" +#include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" +#include "Luau/Scope.h" #include "Luau/StringUtils.h" +#include "Luau/TimeTrace.h" #include "Luau/TypeInfer.h" #include "Luau/Variant.h" #include "Luau/Common.h" @@ -19,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) +LUAU_FASTFLAG(LuauNewRequireTrace) namespace Luau { @@ -69,6 +73,8 @@ static void generateDocumentationSymbols(TypeId ty, const std::string& rootName) LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr targetScope, std::string_view source, const std::string& packageName) { + LUAU_TIMETRACE_SCOPE("loadDefinitionFile", "Frontend"); + Luau::Allocator allocator; Luau::AstNameTable names(allocator); @@ -350,6 +356,9 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) CheckResult Frontend::check(const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto it = sourceNodes.find(name); @@ -479,6 +488,9 @@ CheckResult Frontend::check(const ModuleName& name) bool Frontend::parseGraph(std::vector& buildQueue, CheckResult& checkResult, const ModuleName& root) { + LUAU_TIMETRACE_SCOPE("Frontend::parseGraph", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("root", root.c_str()); + // https://en.wikipedia.org/wiki/Topological_sorting#Depth-first_search enum Mark { @@ -597,6 +609,9 @@ ScopePtr Frontend::getModuleEnvironment(const SourceModule& module, const Config LintResult Frontend::lint(const ModuleName& name, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + CheckResult checkResult; auto [_sourceNode, sourceModule] = getSourceNode(checkResult, name); @@ -608,6 +623,8 @@ LintResult Frontend::lint(const ModuleName& name, std::optional Frontend::lintFragment(std::string_view source, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lintFragment", "Frontend"); + const Config& config = configResolver->getConfig(""); SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); @@ -627,6 +644,9 @@ std::pair Frontend::lintFragment(std::string_view sour CheckResult Frontend::check(const SourceModule& module) { + LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); Mode mode = module.mode.value_or(config.mode); @@ -648,6 +668,9 @@ CheckResult Frontend::check(const SourceModule& module) LintResult Frontend::lint(const SourceModule& module, std::optional enabledLintWarnings) { + LUAU_TIMETRACE_SCOPE("Frontend::lint", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + const Config& config = configResolver->getConfig(module.name); LintOptions options = enabledLintWarnings.value_or(config.enabledLint); @@ -746,6 +769,9 @@ const SourceModule* Frontend::getSourceModule(const ModuleName& moduleName) cons // Read AST into sourceModules if necessary. Trace require()s. Report parse errors. std::pair Frontend::getSourceNode(CheckResult& checkResult, const ModuleName& name) { + LUAU_TIMETRACE_SCOPE("Frontend::getSourceNode", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + auto it = sourceNodes.find(name); if (it != sourceNodes.end() && !it->second.dirty) { @@ -815,6 +841,9 @@ std::pair Frontend::getSourceNode(CheckResult& check */ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const ParseOptions& parseOptions) { + LUAU_TIMETRACE_SCOPE("Frontend::parse", "Frontend"); + LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); + SourceModule sourceModule; double timestamp = getTimestamp(); @@ -864,20 +893,11 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const auto& exprs = it->second.exprs; - const ModuleName* relativeName = exprs.find(&pathExpr); - if (!relativeName || relativeName->empty()) + const ModuleInfo* info = exprs.find(&pathExpr); + if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty())) return std::nullopt; - if (FFlag::LuauTraceRequireLookupChild) - { - const bool* optional = it->second.optional.find(&pathExpr); - - return {{*relativeName, optional ? *optional : false}}; - } - else - { - return {{*relativeName, false}}; - } + return *info; } const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) const @@ -891,12 +911,15 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - return frontend->fileResolver->moduleExists(moduleName); + if (FFlag::LuauNewRequireTrace) + return frontend->sourceNodes.count(moduleName) != 0; + else + return frontend->fileResolver->moduleExists(moduleName); } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const { - return frontend->fileResolver->getHumanReadableModuleName_(moduleName).value_or(moduleName); + return frontend->fileResolver->getHumanReadableModuleName(moduleName); } ScopePtr Frontend::addEnvironment(const std::string& environmentName) diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 84e9b77f7..3b2671213 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -2,6 +2,8 @@ #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" +LUAU_FASTFLAG(LuauTypeAliasPacks) + namespace Luau { @@ -92,7 +94,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "IncorrectGenericParameterCount { name = " << error.name; - if (!error.typeFun.typeParams.empty()) + if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty())) { stream << "<"; bool first = true; @@ -105,6 +107,20 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo stream << toString(t); } + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId t : error.typeFun.typePackParams) + { + if (first) + first = false; + else + stream << ", "; + + stream << toString(t); + } + } + stream << ">"; } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index a1018297e..064accba5 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -3,6 +3,9 @@ #include "Luau/Ast.h" #include "Luau/StringUtils.h" +#include "Luau/Common.h" + +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -612,6 +615,12 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); + + if (FFlag::LuauTypeAliasPacks) + { + PROP(genericPacks); + } + PROP(type); PROP(exported); }); @@ -664,13 +673,21 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(struct AstTypeOrPack node) + { + if (node.type) + write(node.type); + else + write(node.typePack); + } + void write(class AstTypeReference* node) { writeNode(node, "AstTypeReference", [&]() { if (node->hasPrefix) PROP(prefix); PROP(name); - PROP(generics); + PROP(parameters); }); } @@ -734,6 +751,13 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(class AstTypePackExplicit* node) + { + writeNode(node, "AstTypePackExplicit", [&]() { + PROP(typeList); + }); + } + void write(class AstTypePackVariadic* node) { writeNode(node, "AstTypePackVariadic", [&]() { @@ -1018,6 +1042,12 @@ struct AstJsonEncoder : public AstVisitor return false; } + bool visit(class AstTypePackExplicit* node) override + { + write(node); + return false; + } + bool visit(class AstTypePackVariadic* node) override { write(node); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index f97f6a4ad..bff947a56 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/StringUtils.h" #include "Luau/Common.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) +LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false) namespace Luau { @@ -85,10 +87,10 @@ struct LintContext return std::nullopt; auto it = module->astTypes.find(expr); - if (it == module->astTypes.end()) + if (!it) return std::nullopt; - return it->second; + return *it; } }; @@ -2144,6 +2146,19 @@ class LintTableOperations : AstVisitor "wrap it in parentheses to silence"); } + if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4) + { + // table.move(t, 0, _, _) + if (isConstant(args[1], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + + // table.move(t, _, _, 0) + else if (isConstant(args[3], 0.0)) + emitWarning(*context, LintWarning::Code_TableOperations, args[3]->location, + "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + } + return true; } diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index f1d975fec..df6be767b 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -13,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -188,7 +190,7 @@ struct TypePackCloner template void defaultClone(const T& t) { - TypePackId cloned = dest.typePacks.allocate(t); + TypePackId cloned = dest.addTypePack(TypePackVar{t}); seenTypePacks[typePackId] = cloned; } @@ -197,7 +199,7 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.typePacks.allocate(TypePackVar{Unifiable::Error{}}); + seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); } void operator()(const Unifiable::Generic& t) @@ -219,13 +221,13 @@ struct TypePackCloner void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.typePacks.allocate(VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}}); seenTypePacks[typePackId] = cloned; } void operator()(const TypePack& t) { - TypePackId cloned = dest.typePacks.allocate(TypePack{}); + TypePackId cloned = dest.addTypePack(TypePack{}); TypePack* destTp = getMutable(cloned); LUAU_ASSERT(destTp != nullptr); seenTypePacks[typePackId] = cloned; @@ -241,7 +243,7 @@ struct TypePackCloner template void TypeCloner::defaultClone(const T& t) { - TypeId cloned = dest.typeVars.allocate(t); + TypeId cloned = dest.addType(t); seenTypes[typeId] = cloned; } @@ -250,7 +252,7 @@ void TypeCloner::operator()(const Unifiable::Free& t) if (encounteredFreeType) *encounteredFreeType = true; - seenTypes[typeId] = dest.typeVars.allocate(ErrorTypeVar{}); + seenTypes[typeId] = dest.addType(ErrorTypeVar{}); } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -275,7 +277,7 @@ void TypeCloner::operator()(const PrimitiveTypeVar& t) void TypeCloner::operator()(const FunctionTypeVar& t) { - TypeId result = dest.typeVars.allocate(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); + TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); FunctionTypeVar* ftv = getMutable(result); LUAU_ASSERT(ftv != nullptr); @@ -297,7 +299,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { - TypeId result = dest.typeVars.allocate(TableTypeVar{}); + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -323,7 +325,13 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); for (TypeId& arg : ttv->instantiatedTypeParams) - arg = (clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType)); + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + } if (ttv->state == TableState::Free) { @@ -343,7 +351,7 @@ void TypeCloner::operator()(const TableTypeVar& t) void TypeCloner::operator()(const MetatableTypeVar& t) { - TypeId result = dest.typeVars.allocate(MetatableTypeVar{}); + TypeId result = dest.addType(MetatableTypeVar{}); MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; @@ -353,7 +361,7 @@ void TypeCloner::operator()(const MetatableTypeVar& t) void TypeCloner::operator()(const ClassTypeVar& t) { - TypeId result = dest.typeVars.allocate(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); + TypeId result = dest.addType(ClassTypeVar{t.name, {}, std::nullopt, std::nullopt, t.tags, t.userData}); ClassTypeVar* ctv = getMutable(result); seenTypes[typeId] = result; @@ -378,7 +386,7 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - TypeId result = dest.typeVars.allocate(UnionTypeVar{}); + TypeId result = dest.addType(UnionTypeVar{}); seenTypes[typeId] = result; UnionTypeVar* option = getMutable(result); @@ -390,7 +398,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) void TypeCloner::operator()(const IntersectionTypeVar& t) { - TypeId result = dest.typeVars.allocate(IntersectionTypeVar{}); + TypeId result = dest.addType(IntersectionTypeVar{}); seenTypes[typeId] = result; IntersectionTypeVar* option = getMutable(result); @@ -451,8 +459,14 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) { TypeFun result; - for (TypeId param : typeFun.typeParams) - result.typeParams.push_back(clone(param, dest, seenTypes, seenTypePacks, encounteredFreeType)); + for (TypeId ty : typeFun.typeParams) + result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : typeFun.typePackParams) + result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType)); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 5b3997e2c..ad4d5ef43 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) +LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false) namespace Luau { @@ -12,17 +13,18 @@ namespace Luau namespace { -struct RequireTracer : AstVisitor +struct RequireTracerOld : AstVisitor { - explicit RequireTracer(FileResolver* fileResolver, ModuleName currentModuleName) + explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName) : fileResolver(fileResolver) - , currentModuleName(std::move(currentModuleName)) + , currentModuleName(currentModuleName) { + LUAU_ASSERT(!FFlag::LuauNewRequireTrace); } FileResolver* const fileResolver; ModuleName currentModuleName; - DenseHashMap locals{0}; + DenseHashMap locals{nullptr}; RequireTraceResult result; std::optional fromAstFragment(AstExpr* expr) @@ -50,9 +52,9 @@ struct RequireTracer : AstVisitor AstExpr* expr = stat->values.data[i]; expr->visit(this); - const ModuleName* name = result.exprs.find(expr); - if (name) - locals[local] = *name; + const ModuleInfo* info = result.exprs.find(expr); + if (info) + locals[local] = info->name; } } @@ -63,7 +65,7 @@ struct RequireTracer : AstVisitor { std::optional name = fromAstFragment(global); if (name) - result.exprs[global] = *name; + result.exprs[global] = {*name}; return false; } @@ -72,7 +74,7 @@ struct RequireTracer : AstVisitor { const ModuleName* name = locals.find(local->local); if (name) - result.exprs[local] = *name; + result.exprs[local] = {*name}; return false; } @@ -81,16 +83,16 @@ struct RequireTracer : AstVisitor { indexName->expr->visit(this); - const ModuleName* name = result.exprs.find(indexName->expr); - if (name) + const ModuleInfo* info = result.exprs.find(indexName->expr); + if (info) { if (indexName->index == "parent" || indexName->index == "Parent") { - if (auto parent = fileResolver->getParentModuleName(*name)) - result.exprs[indexName] = *parent; + if (auto parent = fileResolver->getParentModuleName(info->name)) + result.exprs[indexName] = {*parent}; } else - result.exprs[indexName] = fileResolver->concat(*name, indexName->index.value); + result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)}; } return false; @@ -100,11 +102,11 @@ struct RequireTracer : AstVisitor { indexExpr->expr->visit(this); - const ModuleName* name = result.exprs.find(indexExpr->expr); + const ModuleInfo* info = result.exprs.find(indexExpr->expr); const AstExprConstantString* str = indexExpr->index->as(); - if (name && str) + if (info && str) { - result.exprs[indexExpr] = fileResolver->concat(*name, std::string_view(str->value.data, str->value.size)); + result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))}; } indexExpr->index->visit(this); @@ -129,8 +131,8 @@ struct RequireTracer : AstVisitor AstExprGlobal* globalName = call->func->as(); if (globalName && globalName->name == "require" && call->args.size >= 1) { - if (const ModuleName* moduleName = result.exprs.find(call->args.data[0])) - result.requires.push_back({*moduleName, call->location}); + if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0])) + result.requires.push_back({moduleInfo->name, call->location}); return false; } @@ -143,8 +145,8 @@ struct RequireTracer : AstVisitor if (FFlag::LuauTraceRequireLookupChild && !rootName) { - if (const ModuleName* moduleName = result.exprs.find(indexName->expr)) - rootName = *moduleName; + if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr)) + rootName = moduleInfo->name; } if (!rootName) @@ -167,24 +169,183 @@ struct RequireTracer : AstVisitor if (v.end() != std::find(v.begin(), v.end(), '/')) return false; - result.exprs[call] = fileResolver->concat(*rootName, v); + result.exprs[call] = {fileResolver->concat(*rootName, v)}; // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime // If we fail to find such module, we will not report an UnknownRequire error if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.optional[call] = true; + result.exprs[call].optional = true; return false; } }; +struct RequireTracer : AstVisitor +{ + RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + : result(result) + , fileResolver(fileResolver) + , currentModuleName(currentModuleName) + , locals(nullptr) + { + LUAU_ASSERT(FFlag::LuauNewRequireTrace); + } + + bool visit(AstExprTypeAssertion* expr) override + { + // suppress `require() :: any` + return false; + } + + bool visit(AstExprCall* expr) override + { + AstExprGlobal* global = expr->func->as(); + + if (global && global->name == "require" && expr->args.size >= 1) + requires.push_back(expr); + + return true; + } + + bool visit(AstStatLocal* stat) override + { + for (size_t i = 0; i < stat->vars.size && i < stat->values.size; ++i) + { + AstLocal* local = stat->vars.data[i]; + AstExpr* expr = stat->values.data[i]; + + // track initializing expression to be able to trace modules through locals + locals[local] = expr; + } + + return true; + } + + bool visit(AstStatAssign* stat) override + { + for (size_t i = 0; i < stat->vars.size; ++i) + { + // locals that are assigned don't have a known expression + if (AstExprLocal* expr = stat->vars.data[i]->as()) + locals[expr->local] = nullptr; + } + + return true; + } + + bool visit(AstType* node) override + { + // allow resolving require inside `typeof` annotations + return true; + } + + AstExpr* getDependent(AstExpr* node) + { + if (AstExprLocal* expr = node->as()) + return locals[expr->local]; + else if (AstExprIndexName* expr = node->as()) + return expr->expr; + else if (AstExprIndexExpr* expr = node->as()) + return expr->expr; + else if (AstExprCall* expr = node->as(); expr && expr->self) + return expr->func->as()->expr; + else + return nullptr; + } + + void process() + { + ModuleInfo moduleContext{currentModuleName}; + + // seed worklist with require arguments + work.reserve(requires.size()); + + for (AstExprCall* require: requires) + work.push_back(require->args.data[0]); + + // push all dependent expressions to the work stack; note that the vector is modified during traversal + for (size_t i = 0; i < work.size(); ++i) + if (AstExpr* dep = getDependent(work[i])) + work.push_back(dep); + + // resolve all expressions to a module info + for (size_t i = work.size(); i > 0; --i) + { + AstExpr* expr = work[i - 1]; + + // when multiple expressions depend on the same one we push it to work queue multiple times + if (result.exprs.contains(expr)) + continue; + + std::optional info; + + if (AstExpr* dep = getDependent(expr)) + { + const ModuleInfo* context = result.exprs.find(dep); + + // locals just inherit their dependent context, no resolution required + if (expr->is()) + info = context ? std::optional(*context) : std::nullopt; + else + info = fileResolver->resolveModule(context, expr); + } + else + { + info = fileResolver->resolveModule(&moduleContext, expr); + } + + if (info) + result.exprs[expr] = std::move(*info); + } + + // resolve all requires according to their argument + result.requires.reserve(requires.size()); + + for (AstExprCall* require : requires) + { + AstExpr* arg = require->args.data[0]; + + if (const ModuleInfo* info = result.exprs.find(arg)) + { + result.requires.push_back({info->name, require->location}); + + ModuleInfo infoCopy = *info; // copy *info out since next line invalidates info! + result.exprs[require] = std::move(infoCopy); + } + else + { + result.exprs[require] = {}; // mark require as unresolved + } + } + } + + RequireTraceResult& result; + FileResolver* fileResolver; + ModuleName currentModuleName; + + DenseHashMap locals; + std::vector work; + std::vector requires; +}; + } // anonymous namespace -RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, ModuleName currentModuleName) +RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - RequireTracer tracer{fileResolver, std::move(currentModuleName)}; - root->visit(&tracer); - return tracer.result; + if (FFlag::LuauNewRequireTrace) + { + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; + root->visit(&tracer); + tracer.process(); + return result; + } + else + { + RequireTracerOld tracer{fileResolver, currentModuleName}; + root->visit(&tracer); + return tracer.result; + } } } // namespace Luau diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp new file mode 100644 index 000000000..c30db9c25 --- /dev/null +++ b/Analysis/src/Scope.cpp @@ -0,0 +1,123 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" + +namespace Luau +{ + +Scope::Scope(TypePackId returnType) + : parent(nullptr) + , returnType(returnType) + , level(TypeLevel()) +{ +} + +Scope::Scope(const ScopePtr& parent, int subLevel) + : parent(parent) + , returnType(parent->returnType) + , level(parent->level.incr()) +{ + level.subLevel = subLevel; +} + +std::optional Scope::lookup(const Symbol& name) +{ + Scope* scope = this; + + while (scope) + { + auto it = scope->bindings.find(name); + if (it != scope->bindings.end()) + return it->second.typeId; + + scope = scope->parent.get(); + } + + return std::nullopt; +} + +std::optional Scope::lookupType(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->exportedTypeBindings.find(name); + if (it != scope->exportedTypeBindings.end()) + return it->second; + + it = scope->privateTypeBindings.find(name); + if (it != scope->privateTypeBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) +{ + const Scope* scope = this; + while (scope) + { + auto it = scope->importedTypeBindings.find(moduleAlias); + if (it == scope->importedTypeBindings.end()) + { + scope = scope->parent.get(); + continue; + } + + auto it2 = it->second.find(name); + if (it2 == it->second.end()) + { + scope = scope->parent.get(); + continue; + } + + return it2->second; + } + + return std::nullopt; +} + +std::optional Scope::lookupPack(const Name& name) +{ + const Scope* scope = this; + while (true) + { + auto it = scope->privateTypePackBindings.find(name); + if (it != scope->privateTypePackBindings.end()) + return it->second; + + if (scope->parent) + scope = scope->parent.get(); + else + return std::nullopt; + } +} + +std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) +{ + Scope* scope = this; + + while (scope) + { + for (const auto& [n, binding] : scope->bindings) + { + if (n.local && n.local->name == name.c_str()) + return binding; + else if (n.global.value && n.global == name.c_str()) + return binding; + } + + scope = scope->parent.get(); + + if (!traverseScopeChain) + break; + } + + return std::nullopt; +} + +} // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 7223998a3..d861eb3da 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -6,9 +6,11 @@ #include #include -LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 0) +LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) +LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauRankNTypes) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -35,8 +37,15 @@ void Tarjan::visitChildren(TypeId ty, int index) visitChild(ttv->indexer->indexType); visitChild(ttv->indexer->indexResultType); } + for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); + } } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -332,9 +341,11 @@ std::optional Substitution::substitute(TypeId ty) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypeId newTy = replace(ty); return newTy; } @@ -350,9 +361,11 @@ std::optional Substitution::substitute(TypePackId tp) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - replaceChildren(newTy); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - replaceChildren(newTp); + if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + replaceChildren(newTp); TypePackId newTp = replace(tp); return newTp; } @@ -382,6 +395,10 @@ TypeId Substitution::clone(TypeId ty) clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; + + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; + if (FFlag::LuauSecondTypecheckKnowsTheDataModel) clone.tags = ttv->tags; result = addType(std::move(clone)); @@ -487,8 +504,15 @@ void Substitution::replaceChildren(TypeId ty) ttv->indexer->indexType = replace(ttv->indexer->indexType); ttv->indexer->indexResultType = replace(ttv->indexer->indexResultType); } + for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); + } } else if (MetatableTypeVar* mtv = getMutable(ty)) { diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 9d2f47ba6..5651af7e9 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/ToString.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -9,10 +10,10 @@ #include #include -LUAU_FASTFLAG(LuauToStringFollowsBoundTo) LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -59,6 +60,13 @@ struct FindCyclicTypes { for (TypeId itp : ttv.instantiatedTypeParams) visitTypeVar(itp, *this, seen); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv.instantiatedTypePackParams) + visitTypeVar(itp, *this, seen); + } + return exhaustive; } @@ -258,23 +266,60 @@ struct TypeVarStringifier void stringify(TypePackId tp); void stringify(TypePackId tpid, const std::vector>& names); - void stringify(const std::vector& types) + void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0) + if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0)) return; - if (types.size()) + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit("<"); - for (size_t i = 0; i < types.size(); ++i) + if (FFlag::LuauTypeAliasPacks) { - if (i > 0) - state.emit(", "); + bool first = true; + + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; + + stringify(ty); + } + + bool singleTp = typePacks.size() == 1; + + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; + + if (!first) + state.emit(", "); + else + first = false; + + if (!singleTp) + state.emit("("); + + stringify(tp); - stringify(types[i]); + if (!singleTp) + state.emit(")"); + } } + else + { + for (size_t i = 0; i < types.size(); ++i) + { + if (i > 0) + state.emit(", "); - if (types.size()) + stringify(types[i]); + } + } + + if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) state.emit(">"); } @@ -388,7 +433,7 @@ struct TypeVarStringifier void operator()(TypeId, const TableTypeVar& ttv) { - if (FFlag::LuauToStringFollowsBoundTo && ttv.boundTo) + if (ttv.boundTo) return stringify(*ttv.boundTo); if (!state.exhaustive) @@ -411,14 +456,14 @@ struct TypeVarStringifier } state.emit(*ttv.name); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } if (ttv.syntheticName) { state.result.invalid = true; state.emit(*ttv.syntheticName); - stringify(ttv.instantiatedTypeParams); + stringify(ttv.instantiatedTypeParams, ttv.instantiatedTypePackParams); return; } } @@ -900,13 +945,26 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) return result; std::vector params; for (TypeId tp : ttv->instantiatedTypeParams) params.push_back(toString(tp)); + if (FFlag::LuauTypeAliasPacks) + { + // Doesn't preserve grouping of multiple type packs + // But this is under a parent block of code that is being removed later + for (TypePackId tp : ttv->instantiatedTypePackParams) + { + std::string content = toString(tp); + + if (!content.empty()) + params.push_back(std::move(content)); + } + } + result.name += "<" + join(params, ", ") + ">"; return result; } @@ -950,30 +1008,37 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (ttv->instantiatedTypeParams.empty()) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) + if (FFlag::LuauTypeAliasPacks) { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); } else { - result.name += ">"; + if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) + return result; + + result.name += "<"; + + bool first = true; + for (TypeId ty : ttv->instantiatedTypeParams) + { + if (!first) + result.name += ", "; + else + first = false; + + tvs.stringify(ty); + } + + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + { + result.truncated = true; + result.name += "... "; + } + else + { + result.name += ">"; + } } return result; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 462c70ffc..1b83ccdc2 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { @@ -280,10 +281,19 @@ struct Printer void visualizeTypePackAnnotation(const AstTypePack& annotation) { - if (const AstTypePackVariadic* variadic = annotation.as()) + if (const AstTypePackVariadic* variadicTp = annotation.as()) { writer.symbol("..."); - visualizeTypeAnnotation(*variadic->variadicType); + visualizeTypeAnnotation(*variadicTp->variadicType); + } + else if (const AstTypePackGeneric* genericTp = annotation.as()) + { + writer.symbol(genericTp->genericName.value); + writer.symbol("..."); + } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + visualizeTypeList(explicitTp->typeList, true); } else { @@ -807,7 +817,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0) + if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0)) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -817,6 +827,17 @@ struct Printer comma(); writer.identifier(o.value); } + + if (FFlag::LuauTypeAliasPacks) + { + for (auto o : a->genericPacks) + { + comma(); + writer.identifier(o.value); + writer.symbol("..."); + } + } + writer.symbol(">"); } writer.maybeSpace(a->type->location.begin, 2); @@ -960,15 +981,20 @@ struct Printer if (const auto& a = typeAnnotation.as()) { writer.write(a->name.value); - if (a->generics.size > 0) + if (a->parameters.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); - for (auto o : a->generics) + for (auto o : a->parameters) { comma(); - visualizeTypeAnnotation(*o); + + if (o.type) + visualizeTypeAnnotation(*o.type); + else + visualizeTypePackAnnotation(*o.typePack); } + writer.symbol(">"); } } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 17c57c848..266c19865 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -5,6 +5,7 @@ #include "Luau/Module.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -12,6 +13,7 @@ #include LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAG(LuauTypeAliasPacks) static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { @@ -33,7 +35,6 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data namespace Luau { - class TypeRehydrationVisitor { mutable std::map seen; @@ -57,6 +58,8 @@ class TypeRehydrationVisitor { } + AstTypePack* rehydrate(TypePackId tp) const; + AstType* operator()(const PrimitiveTypeVar& ptv) const { switch (ptv.type) @@ -85,16 +88,24 @@ class TypeRehydrationVisitor if (ttv.name && options.bannedNames.find(*ttv.name) == options.bannedNames.end()) { - AstArray generics; - generics.size = ttv.instantiatedTypeParams.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstType*) * generics.size)); + AstArray parameters; + parameters.size = ttv.instantiatedTypeParams.size(); + parameters.data = static_cast(allocator->allocate(sizeof(AstTypeOrPack) * parameters.size)); for (size_t i = 0; i < ttv.instantiatedTypeParams.size(); ++i) { - generics.data[i] = Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty); + parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; + } + + if (FFlag::LuauTypeAliasPacks) + { + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) + { + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; + } } - return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), generics); + return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); } if (hasSeen(&ttv)) @@ -222,10 +233,17 @@ class TypeRehydrationVisitor AstTypePack* argTailAnnotation = nullptr; if (argTail) { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + argTailAnnotation = rehydrate(*argTail); + } + else + { + TypePackId tail = *argTail; + if (const VariadicTypePack* vtp = get(tail)) + { + argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -255,10 +273,17 @@ class TypeRehydrationVisitor AstTypePack* retTailAnnotation = nullptr; if (retTail) { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) + if (FFlag::LuauTypeAliasPacks) + { + retTailAnnotation = rehydrate(*retTail); + } + else { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + TypePackId tail = *retTail; + if (const VariadicTypePack* vtp = get(tail)) + { + retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); + } } } @@ -313,6 +338,68 @@ class TypeRehydrationVisitor const TypeRehydrationOptions& options; }; +class TypePackRehydrationVisitor +{ +public: + TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + : allocator(allocator) + , typeVisitor(typeVisitor) + { + } + + AstTypePack* operator()(const BoundTypePack& btp) const + { + return Luau::visit(*this, btp.boundTo->ty); + } + + AstTypePack* operator()(const TypePack& tp) const + { + AstArray head; + head.size = tp.head.size(); + head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); + + for (size_t i = 0; i < tp.head.size(); i++) + head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + + AstTypePack* tail = nullptr; + + if (tp.tail) + tail = Luau::visit(*this, (*tp.tail)->ty); + + return allocator->alloc(Location(), AstTypeList{head, tail}); + } + + AstTypePack* operator()(const VariadicTypePack& vtp) const + { + return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + } + + AstTypePack* operator()(const GenericTypePack& gtp) const + { + return allocator->alloc(Location(), AstName(gtp.name.c_str())); + } + + AstTypePack* operator()(const FreeTypePack& gtp) const + { + return allocator->alloc(Location(), AstName("free")); + } + + AstTypePack* operator()(const Unifiable::Error&) const + { + return allocator->alloc(Location(), AstName("Unifiable")); + } + +private: + Allocator* allocator; + const TypeRehydrationVisitor& typeVisitor; +}; + +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +{ + TypePackRehydrationVisitor tprv(allocator, *this); + return Luau::visit(tprv, tp->ty); +} + class TypeAttacher : public AstVisitor { public: @@ -406,9 +493,16 @@ class TypeAttacher : public AstVisitor if (tail) { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + if (FFlag::LuauTypeAliasPacks) + { + variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + } + else + { + TypePackId tailPack = *tail; + if (const VariadicTypePack* vtp = get(tailPack)) + variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); + } } fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 2216881b7..3a1fdfff5 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5,21 +5,22 @@ #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" #include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TypeVar.h" +#include "Luau/TimeTrace.h" #include #include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) -LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 0) -LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) +LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauIndexTablesWithIndexers, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) LUAU_FASTFLAG(LuauKnowsTheDataModel3) @@ -27,14 +28,11 @@ LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauImprovedTypeGuardPredicate2, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) -LUAU_FASTFLAG(DebugLuauTrackOwningArena) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauFixTableTypeAliasClone, false) LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) @@ -45,6 +43,10 @@ LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) +LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -216,9 +218,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) - , booleanType( - FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.booleanType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Boolean))) - , threadType(FFlag::LuauImprovedTypeGuardPredicate2 ? singletonTypes.threadType : globalTypes.addType(PrimitiveTypeVar(PrimitiveTypeVar::Thread))) + , booleanType(singletonTypes.booleanType) + , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) , errorType(singletonTypes.errorType) , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) @@ -237,6 +238,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optional environmentScope) { + LUAU_TIMETRACE_SCOPE("TypeChecker::check", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("module", module.name.c_str()); + currentModule.reset(new Module()); currentModule->type = module.type; @@ -1177,44 +1181,61 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + + if (FFlag::LuauTypeAliasPacks) + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + else + bindingsMap[name] = TypeFun{binding->typeParams, errorType}; } else { ScopePtr aliasScope = childScope(scope, typealias.location); - std::vector generics; - for (AstName generic : typealias.generics) + if (FFlag::LuauTypeAliasPacks) { - Name n = generic.value; + auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; + } + else + { + std::vector generics; + for (AstName generic : typealias.generics) { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } + Name n = generic.value; - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; + // These generics are the only thing that will ever be added to aliasScope, so we can be certain that + // a collision can only occur when two generic typevars have the same name. + if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) + { + // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. + reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); + } + + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction) + { + TypeId& cached = scope->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{aliasScope->level, n}); + g = cached; + } + else + g = addType(GenericTypeVar{aliasScope->level, n}); + generics.push_back(g); + aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; - } - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; + TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), ty}; + } } } else @@ -1231,6 +1252,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; } + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId tp : binding->typePackParams) + { + auto generic = get(tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = tp; + } + } + TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); if (auto ttv = getMutable(follow(ty))) { @@ -1238,7 +1269,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (ttv->name) { // Copy can be skipped if this is an identical alias - if (!FFlag::LuauFixTableTypeAliasClone || ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams) + if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || + (FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams)) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1249,6 +1281,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.name = name; clone.instantiatedTypeParams = binding->typeParams; + if (FFlag::LuauTypeAliasPacks) + clone.instantiatedTypePackParams = binding->typePackParams; + ty = addType(std::move(clone)); } } @@ -1256,6 +1291,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ttv->name = name; ttv->instantiatedTypeParams = binding->typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = binding->typePackParams; } } else if (auto mtv = getMutable(follow(ty))) @@ -1280,7 +1318,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); superTy = lookupType->type; if (FFlag::LuauAddMissingFollow) @@ -1465,7 +1503,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FFlag::LuauStoreMatchingOverloadFnType) { - currentModule->astTypes.try_emplace(&expr, result.type); + if (!currentModule->astTypes.find(&expr)) + currentModule->astTypes[&expr] = result.type; } else { @@ -2193,7 +2232,7 @@ TypeId TypeChecker::checkRelationalOperation( * have a better, more descriptive error teed up. */ Unifier state = mkUnifier(expr.location); - if (!FFlag::LuauEqConstraint || !isEquality) + if (!isEquality) state.tryUnify(lhsType, rhsType); bool needsMetamethod = !isEquality; @@ -2262,7 +2301,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && (!FFlag::LuauEqConstraint || !isEquality)) + if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2276,18 +2315,6 @@ TypeId TypeChecker::checkRelationalOperation( return errorType; } - if (!FFlag::LuauEqConstraint) - { - if (isEquality) - { - ErrorVec errVec = tryUnify(rhsType, lhsType, expr.location); - if (!state.errors.empty() && !errVec.empty()) - reportError(expr.location, TypeMismatch{lhsType, rhsType}); - } - else - reportErrors(state.errors); - } - return booleanType; } @@ -2443,7 +2470,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } - else if (FFlag::LuauEqConstraint && (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe)) + else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; @@ -2466,14 +2493,6 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi } else { - // Once we have EqPredicate, we should break this else branch into its' own branch. - // For now, fall through is intentional. - if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) - { - if (auto predicate = tryGetTypeGuardPredicate(expr)) - return {booleanType, {std::move(*predicate)}}; - } - ExprResult lhs = checkExpr(scope, *expr.left); ExprResult rhs = checkExpr(scope, *expr.right); @@ -2755,12 +2774,6 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return std::pair(resultType, nullptr); } - else if (FFlag::LuauIndexTablesWithIndexers) - { - // We allow t[x] where x:string for tables without an indexer - unify(indexType, stringType, expr.location); - return std::pair(anyType, nullptr); - } else { TypeId resultType = freshType(scope); @@ -3076,6 +3089,13 @@ static Location getEndLocation(const AstExprFunction& function) void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstExprFunction& function) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkFunctionBody", "TypeChecker"); + + if (function.debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", function.debugname.value); + else + LUAU_TIMETRACE_ARGUMENT("line", std::to_string(function.location.begin.line).c_str()); + if (FunctionTypeVar* funTy = getMutable(ty)) { check(scope, *function.body); @@ -3885,6 +3905,20 @@ std::optional TypeChecker::matchRequire(const AstExprCall& call) TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& moduleInfo, const Location& location) { + LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); + LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); + + if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty()) + { + if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + { + reportError(TypeError{location, UnknownRequire{}}); + return errorType; + } + + return anyType; + } + ModulePtr module = resolver->getModule(moduleInfo.name); if (!module) { @@ -4472,7 +4506,7 @@ TypeId TypeChecker::freshType(const ScopePtr& scope) TypeId TypeChecker::freshType(TypeLevel level) { - return currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level))); + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) @@ -4482,11 +4516,7 @@ TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneri TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(TypeVar(FreeTypeVar(level, canBeGeneric))); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric))); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4506,20 +4536,12 @@ TypeId TypeChecker::addType(const UnionTypeVar& utv) TypeId TypeChecker::addTV(TypeVar&& tv) { - TypeId allocated = currentModule->internalTypes.typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addType(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePackVar&& tv) { - TypePackId allocated = currentModule->internalTypes.typePacks.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) - asMutable(allocated)->owningArena = ¤tModule->internalTypes; - - return allocated; + return currentModule->internalTypes.addTypePack(std::move(tv)); } TypePackId TypeChecker::addTypePack(TypePack&& tp) @@ -4578,7 +4600,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_print") { - if (lit->generics.size != 1) + if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); return addType(ErrorTypeVar{}); @@ -4588,7 +4610,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation opts.exhaustive = true; opts.maxTableLength = 0; - TypeId param = resolveType(scope, *lit->generics.data[0]); + TypeId param = resolveType(scope, *lit->parameters.data[0].type); luauPrintLine(format("_luau_print\t%s\t|\t%s", toString(param, opts).c_str(), toString(lit->location).c_str())); return param; } @@ -4614,18 +4636,86 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(ErrorTypeVar{}); } - if (lit->generics.size == 0 && tf->typeParams.empty()) + if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) + { return tf->type; - else if (lit->generics.size != tf->typeParams.size()) + } + else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->generics.size}}); + reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); return addType(ErrorTypeVar{}); } + else if (FFlag::LuauTypeAliasPacks) + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + return addType(ErrorTypeVar{}); + } + + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; + + for (size_t i = 0; i < lit->parameters.size; ++i) + { + if (AstType* type = lit->parameters.data[i].type) + { + TypeId ty = resolveType(scope, *type); + + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); + else + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); + } + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) + { + TypePackId tp = resolveTypePack(scope, *typePack); + + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); + else + typePackParams.push_back(tp); + } + } + + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); + + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + return addType(ErrorTypeVar{}); + } + + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; + } + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); + } else { std::vector typeParams; - for (AstType* paramAnnot : lit->generics) - typeParams.push_back(resolveType(scope, *paramAnnot)); + + for (const auto& param : lit->parameters) + typeParams.push_back(resolveType(scope, *param.type)); if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { @@ -4634,7 +4724,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return tf->type; } - return instantiateTypeFun(scope, *tf, typeParams, annotation.location); + return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location); } } else if (const auto& table = annotation.as()) @@ -4765,6 +4855,18 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack return *genericTy; } + else if (const AstTypePackExplicit* explicitTp = annotation.as()) + { + std::vector types; + + for (auto type : explicitTp->typeList.types) + types.push_back(resolveType(scope, *type)); + + if (auto tailType = explicitTp->typeList.tailType) + return addTypePack(types, resolveTypePack(scope, *tailType)); + + return addTypePack(types); + } else { ice("Unknown AstTypePack kind"); @@ -4799,12 +4901,28 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) return false; } +bool ApplyTypeFunction::ignoreChildren(TypeId ty) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(ty)) + return true; + else + return false; +} + +bool ApplyTypeFunction::ignoreChildren(TypePackId tp) +{ + if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(tp)) + return true; + else + return false; +} + TypeId ApplyTypeFunction::clean(TypeId ty) { // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - TypeId& arg = arguments[ty]; + TypeId& arg = typeArguments[ty]; if (arg) return arg; else @@ -4816,17 +4934,37 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - return addTypePack(FreeTypePack{level}); + if (FFlag::LuauTypeAliasPacks) + { + TypePackId& arg = typePackArguments[tp]; + if (arg) + return arg; + else + return addTypePack(FreeTypePack{level}); + } + else + { + return addTypePack(FreeTypePack{level}); + } } -TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const Location& location) +TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, + const std::vector& typePackParams, const Location& location) { - if (tf.typeParams.empty()) + if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty())) return tf.type; - applyTypeFunction.arguments.clear(); + applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.arguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + + if (FFlag::LuauTypeAliasPacks) + { + applyTypeFunction.typePackArguments.clear(); + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + } + applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; applyTypeFunction.encounteredForwardedType = false; @@ -4875,6 +5013,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (ttv) { ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } else @@ -4890,6 +5031,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } ttv->instantiatedTypeParams = typeParams; + + if (FFlag::LuauTypeAliasPacks) + ttv->instantiatedTypePackParams = typePackParams; } } @@ -4899,6 +5043,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, std::pair, std::vector> TypeChecker::createGenericTypes( const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { + LUAU_ASSERT(scope->parent); + std::vector generics; for (const AstName& generic : genericNames) { @@ -4912,7 +5058,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypeId g = addType(Unifiable::Generic{scope->level, n}); + TypeId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypeId& cached = scope->parent->typeAliasTypeParameters[n]; + if (!cached) + cached = addType(GenericTypeVar{scope->level, n}); + g = cached; + } + else + { + g = addType(Unifiable::Generic{scope->level, n}); + } + generics.push_back(g); scope->privateTypeBindings[n] = TypeFun{{}, g}; } @@ -4930,7 +5088,19 @@ std::pair, std::vector> TypeChecker::createGener reportError(TypeError{node.location, DuplicateGenericParameter{n}}); } - TypePackId g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + TypePackId g; + if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + { + TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; + if (!cached) + cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = cached; + } + else + { + g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + } + genericPacks.push_back(g); scope->privateTypePackBindings[n] = g; } @@ -5013,13 +5183,8 @@ void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, Refineme else if (auto isaP = get(predicate)) resolve(*isaP, errVec, refis, scope, sense); else if (auto typeguardP = get(predicate)) - { - if (FFlag::LuauImprovedTypeGuardPredicate2) - resolve(*typeguardP, errVec, refis, scope, sense); - else - DEPRECATED_resolve(*typeguardP, errVec, refis, scope, sense); - } - else if (auto eqP = get(predicate); eqP && FFlag::LuauEqConstraint) + resolve(*typeguardP, errVec, refis, scope, sense); + else if (auto eqP = get(predicate)) resolve(*eqP, errVec, refis, scope, sense); else ice("Unhandled predicate kind"); @@ -5145,7 +5310,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } } - else if (FFlag::LuauImprovedTypeGuardPredicate2) + else { auto lctv = get(option); auto rctv = get(isaP.ty); @@ -5159,19 +5324,6 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement if (canUnify(option, isaP.ty, isaP.location).empty() == sense) return isaP.ty; } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (lctv && rctv) - { - if (isSubclass(lctv, rctv) == sense) - return option; - else if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - } - } return std::nullopt; }; @@ -5266,7 +5418,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty()) + if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty())) return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); TypeId type = follow(typeFun->type); @@ -5292,7 +5444,8 @@ void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, Error "userdata", // no op. Requires special handling. }; - if (auto typeFun = globalScope->lookupType(typeguardP.kind); typeFun && typeFun->typeParams.empty()) + if (auto typeFun = globalScope->lookupType(typeguardP.kind); + typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) { if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) addRefinement(refis, typeguardP.lvalue, typeFun->type); @@ -5319,38 +5472,41 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return; } - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; + if (FFlag::LuauEqConstraint) + { + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) - { - addRefinement(refis, eqP.lvalue, eqP.type); - return; - } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + { + addRefinement(refis, eqP.lvalue, eqP.type); + return; + } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - std::unordered_set set; - for (TypeId left : lhs) - { - for (TypeId right : rhs) + std::unordered_set set; + for (TypeId left : lhs) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } } - } - if (set.empty()) - return; + if (set.empty()) + return; - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); + } } bool TypeChecker::isNonstrictMode() const @@ -5379,119 +5535,4 @@ std::vector> TypeChecker::getScopes() const return currentModule->scopes; } -Scope::Scope(TypePackId returnType) - : parent(nullptr) - , returnType(returnType) - , level(TypeLevel()) -{ -} - -Scope::Scope(const ScopePtr& parent, int subLevel) - : parent(parent) - , returnType(parent->returnType) - , level(parent->level.incr()) -{ - level.subLevel = subLevel; -} - -std::optional Scope::lookup(const Symbol& name) -{ - Scope* scope = this; - - while (scope) - { - auto it = scope->bindings.find(name); - if (it != scope->bindings.end()) - return it->second.typeId; - - scope = scope->parent.get(); - } - - return std::nullopt; -} - -std::optional Scope::lookupType(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->exportedTypeBindings.find(name); - if (it != scope->exportedTypeBindings.end()) - return it->second; - - it = scope->privateTypeBindings.find(name); - if (it != scope->privateTypeBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::lookupImportedType(const Name& moduleAlias, const Name& name) -{ - const Scope* scope = this; - while (scope) - { - auto it = scope->importedTypeBindings.find(moduleAlias); - if (it == scope->importedTypeBindings.end()) - { - scope = scope->parent.get(); - continue; - } - - auto it2 = it->second.find(name); - if (it2 == it->second.end()) - { - scope = scope->parent.get(); - continue; - } - - return it2->second; - } - - return std::nullopt; -} - -std::optional Scope::lookupPack(const Name& name) -{ - const Scope* scope = this; - while (true) - { - auto it = scope->privateTypePackBindings.find(name); - if (it != scope->privateTypePackBindings.end()) - return it->second; - - if (scope->parent) - scope = scope->parent.get(); - else - return std::nullopt; - } -} - -std::optional Scope::linearSearchForBinding(const std::string& name, bool traverseScopeChain) -{ - Scope* scope = this; - - while (scope) - { - for (const auto& [n, binding] : scope->bindings) - { - if (n.local && n.local->name == name.c_str()) - return binding; - else if (n.global.value && n.global == name.c_str()) - return binding; - } - - scope = scope->parent.get(); - - if (!traverseScopeChain) - break; - } - - return std::nullopt; -} - } // namespace Luau diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 5970f304d..68a16ef04 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -209,6 +209,19 @@ size_t size(TypePackId tp) return 0; } +bool finite(TypePackId tp) +{ + tp = follow(tp); + + if (auto pack = get(tp)) + return pack->tail ? finite(*pack->tail) : true; + + if (auto pack = get(tp)) + return false; + + return true; +} + size_t size(const TypePack& tp) { size_t result = tp.head.size(); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index b9f509788..0d9d91e0c 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -1,11 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeUtils.h" +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" -LUAU_FASTFLAG(LuauStringMetatable) - namespace Luau { @@ -13,21 +12,6 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto it = globalScope->bindings.find(AstName{"string"}); - if (it != globalScope->bindings.end()) - return it->second.typeId; - else - return std::nullopt; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 111f4f53d..e963fc74e 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,11 +19,9 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) -LUAU_FASTFLAGVARIABLE(LuauToStringFollowsBoundTo, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAGVARIABLE(LuauStringMetatable, false) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) +LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -193,27 +191,11 @@ bool isOptional(TypeId ty) bool isTableIntersection(TypeId ty) { - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - if (!get(follow(ty))) - return false; - - std::vector parts = flattenIntersection(ty); - return std::all_of(parts.begin(), parts.end(), getTableType); - } - else - { - if (const IntersectionTypeVar* itv = get(ty)) - { - for (TypeId part : itv->parts) - { - if (getTableType(follow(part))) - return true; - } - } - + if (!get(follow(ty))) return false; - } + + std::vector parts = flattenIntersection(ty); + return std::all_of(parts.begin(), parts.end(), getTableType); } bool isOverloadedFunction(TypeId ty) @@ -236,7 +218,7 @@ std::optional getMetatable(TypeId type) else if (const ClassTypeVar* classType = get(type)) return classType->metatable; else if (const PrimitiveTypeVar* primitiveType = get(type); - FFlag::LuauStringMetatable && primitiveType && primitiveType->metatable) + primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -871,6 +853,12 @@ void StateDot::visitChildren(TypeId ty, int index) } for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp, index, "typeParam"); + + if (FFlag::LuauTypeAliasPacks) + { + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } } else if (const MetatableTypeVar* mtv = get(ty)) { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 89c3f80ca..117cbc289 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -3,23 +3,25 @@ #include "Luau/Common.h" #include "Luau/RecursionCounter.h" +#include "Luau/Scope.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/TimeTrace.h" #include LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 0); -LUAU_FASTFLAGVARIABLE(LuauLogTableTypeVarBoundTo, false) +LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAG(LuauGenericFunctions) +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauStringMetatable) LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) +LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) namespace Luau { @@ -43,21 +45,23 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , globalScope(std::move(globalScope)) , location(location) , variance(variance) - , counters(std::make_shared()) + , counters(&countersData) + , counters_DEPRECATED(std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters) + Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) , log(seen) , location(location) , variance(variance) - , counters(counters ? counters : std::make_shared()) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , iceHandler(iceHandler) { LUAU_ASSERT(iceHandler); @@ -65,16 +69,26 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::v void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -440,7 +454,11 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - counters->iterationCount = 0; + if (FFlag::LuauTypecheckOpts) + counters->iterationCount = 0; + else + counters_DEPRECATED->iterationCount = 0; + return tryUnify_(superTp, subTp, isFunctionCall); } @@ -450,10 +468,16 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + + if (FFlag::LuauTypecheckOpts) + ++counters->iterationCount; + else + ++counters_DEPRECATED->iterationCount; - ++counters->iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < counters->iterationCount) + if (FInt::LuauTypeInferIterationLimit > 0 && + FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -762,9 +786,210 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - std::unique_ptr resetter; + if (!FFlag::LuauTableSubtypingVariance) + return DEPRECATED_tryUnifyTables(left, right, isIntersection); - resetter.reset(new Resetter{&variance}); + TableTypeVar* lt = getMutable(left); + TableTypeVar* rt = getMutable(right); + if (!lt || !rt) + ice("passed non-table types to unifyTables"); + + std::vector missingProperties; + std::vector extraProperties; + + // Reminder: left is the supertype, right is the subtype. + // Width subtyping: any property in the supertype must be in the subtype, + // and the types must agree. + for (const auto& [name, prop] : lt->props) + { + const auto& r = rt->props.find(name); + if (r != rt->props.end()) + { + // TODO: read-only properties don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, r->second.type); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (rt->indexer && isString(rt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, rt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (rt->state == TableState::Free) + { + log(rt); + rt->props[name] = prop; + } + else + missingProperties.push_back(name); + } + + for (const auto& [name, prop] : rt->props) + { + if (lt->props.count(name)) + { + // If both lt and rt contain the property, then + // we're done since we already unified them above + } + else if (lt->indexer && isString(lt->indexer->indexType)) + { + // TODO: read-only indexers don't need invariance + // TODO: really we should only allow this if prop.type is optional. + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, lt->indexer->indexResultType); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->state == TableState::Unsealed) + { + // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. + // TODO: file a JIRA + // TODO: hopefully readonly/writeonly properties will fix this. + Property clone = prop; + clone.type = deeplyOptional(clone.type); + log(lt); + lt->props[name] = clone; + } + else if (variance == Covariant) + {} + else if (isOptional(prop.type) || get(follow(prop.type))) + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + {} + else if (lt->state == TableState::Free) + { + log(lt); + lt->props[name] = prop; + } + else + extraProperties.push_back(name); + } + + // Unify indexers + if (lt->indexer && rt->indexer) + { + // TODO: read-only indexers don't need invariance + Resetter resetter{&variance}; + variance = Invariant; + + Unifier innerState = makeChildUnifier(); + innerState.tryUnify(*lt->indexer, *rt->indexer); + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + else + innerState.log.rollback(); + } + else if (lt->indexer) + { + if (rt->state == TableState::Unsealed || rt->state == TableState::Free) + { + // passing/assigning a table without an indexer to something that has one + // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. + // TODO: we only need to do this if the supertype's indexer is read/write + // since that can add indexed elements. + log(rt); + rt->indexer = lt->indexer; + } + } + else if (rt->indexer && variance == Invariant) + { + // Symmetric if we are invariant + if (lt->state == TableState::Unsealed || lt->state == TableState::Free) + { + log(lt); + lt->indexer = rt->indexer; + } + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + + /* + * TypeVars are commonly cyclic, so it is entirely possible + * for unifying a property of a table to change the table itself! + * We need to check for this and start over if we notice this occurring. + * + * I believe this is guaranteed to terminate eventually because this will + * only happen when a free table is bound to another table. + */ + if (lt->boundTo || rt->boundTo) + return tryUnify_(left, right); + + if (lt->state == TableState::Free) + { + log(lt); + lt->boundTo = right; + } + else if (rt->state == TableState::Free) + { + log(rt); + rt->boundTo = left; + } +} + +TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) +{ + ty = follow(ty); + if (get(ty)) + return ty; + else if (isOptional(ty)) + return ty; + else if (const TableTypeVar* ttv = get(ty)) + { + TypeId& result = seen[ty]; + if (result) + return result; + result = types->addType(*ttv); + TableTypeVar* resultTtv = getMutable(result); + for (auto& [name, prop] : resultTtv->props) + prop.type = deeplyOptional(prop.type, seen); + return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + } + else + return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); +} + +void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +{ + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + Resetter resetter{&variance}; variance = Invariant; TableTypeVar* lt = getMutable(left); @@ -894,10 +1119,7 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) if (!freeTable->boundTo && otherTable->state != TableState::Free) { - if (FFlag::LuauLogTableTypeVarBoundTo) - log(freeTable); - else - log(freeTypeId); + log(freeTable); freeTable->boundTo = otherTypeId; } } @@ -1196,9 +1418,11 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); } -static void queueTypePack( +static void queueTypePack_DEPRECATED( std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + while (true) { if (FFlag::LuauAddMissingFollow) @@ -1244,6 +1468,55 @@ static void queueTypePack( } } +static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (true) + { + if (FFlag::LuauAddMissingFollow) + a = follow(a); + + if (seenTypePacks.find(a)) + break; + seenTypePacks.insert(a); + + if (FFlag::LuauAddMissingFollow) + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + else + { + if (get(a)) + { + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + + if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } + } + } +} + void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) { const VariadicTypePack* lv = get(superTp); @@ -1297,9 +1570,11 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } } -static void tryUnifyWithAny( +static void tryUnifyWithAny_DEPRECATED( std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) { + LUAU_ASSERT(!FFlag::LuauTypecheckOpts); + std::unordered_set seen; while (!queue.empty()) @@ -1310,6 +1585,59 @@ static void tryUnifyWithAny( continue; seen.insert(ty); + if (get(ty)) + { + state.log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } +} + +static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, + TypeId anyType, TypePackId anyTypePack) +{ + LUAU_ASSERT(FFlag::LuauTypecheckOpts); + + while (!queue.empty()) + { + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); + if (get(ty)) { state.log(ty); @@ -1354,14 +1682,33 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { LUAU_ASSERT(get(any) || get(any)); + if (FFlag::LuauTypecheckOpts) + { + // These types are not visited in general loop below + if (get(ty) || get(ty) || get(ty)) + return; + } + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - std::unordered_set seenTypePacks; - std::vector queue = {ty}; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue = {ty}; + + tempSeenTy.clear(); + tempSeenTp.clear(); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue = {ty}; - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); + } } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) @@ -1370,12 +1717,26 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) const TypeId anyTy = singletonTypes.errorType; - std::unordered_set seenTypePacks; - std::vector queue; + if (FFlag::LuauTypecheckOpts) + { + std::vector queue; - queueTypePack(queue, seenTypePacks, *this, ty, any); + tempSeenTy.clear(); + tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, seenTypePacks, anyTy, any); + queueTypePack(queue, tempSeenTp, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + } + else + { + std::unordered_set seenTypePacks; + std::vector queue; + + queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any); + + Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any); + } } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -1387,21 +1748,6 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry { type = follow(type); - if (!FFlag::LuauStringMetatable) - { - if (const PrimitiveTypeVar* primType = get(type)) - { - if (primType->type != PrimitiveTypeVar::String || "__index" != entry) - return std::nullopt; - - auto found = globalScope->bindings.find(AstName{"string"}); - if (found == globalScope->bindings.end()) - return std::nullopt; - else - return found->second.typeId; - } - } - std::optional metatable = getMetatable(type); if (!metatable) return std::nullopt; @@ -1427,21 +1773,36 @@ std::optional Unifier::findMetatableEntry(TypeId type, std::string entry void Unifier::occursCheck(TypeId needle, TypeId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTy.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); needle = follow(needle); haystack = follow(haystack); - if (seen.end() != seen.find(haystack)) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; + + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; - seen.insert(haystack); + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1458,7 +1819,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI } auto check = [&](TypeId tv) { - occursCheck(seen, needle, tv); + occursCheck(seen_DEPRECATED, seen, needle, tv); }; if (get(haystack)) @@ -1488,19 +1849,33 @@ void Unifier::occursCheck(std::unordered_set& seen, TypeId needle, TypeI void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen; - return occursCheck(seen, needle, haystack); + std::unordered_set seen_DEPRECATED; + + if (FFlag::LuauTypecheckOpts) + tempSeenTp.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen, TypePackId needle, TypePackId haystack) +void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); haystack = follow(haystack); - if (seen.find(haystack) != seen.end()) - return; + if (FFlag::LuauTypecheckOpts) + { + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); + } + else + { + if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) + return; + + seen_DEPRECATED.insert(haystack); + } if (get(needle)) return; @@ -1508,7 +1883,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl if (!get(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra(&counters->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra( + FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); while (!get(haystack)) { @@ -1528,8 +1904,8 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl { if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); + occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); + occursCheck(seen_DEPRECATED, seen, needle, f->retType); } } } @@ -1546,7 +1922,7 @@ void Unifier::occursCheck(std::unordered_set& seen, TypePackId needl Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters}; + return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index df38cfecf..a2189f7b7 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -264,6 +264,10 @@ class AstVisitor { return false; } + virtual bool visit(class AstTypePackExplicit* node) + { + return visit((class AstTypePack*)node); + } virtual bool visit(class AstTypePackVariadic* node) { return visit((class AstTypePack*)node); @@ -930,12 +934,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, + AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -1007,19 +1013,28 @@ class AstType : public AstNode } }; +// Don't have Luau::Variant available, it's a bit of an overhead, but a plain struct is nice to use +struct AstTypeOrPack +{ + AstType* type = nullptr; + AstTypePack* typePack = nullptr; +}; + class AstTypeReference : public AstType { public: LUAU_RTTI(AstTypeReference) - AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics = {}); + AstTypeReference(const Location& location, std::optional prefix, AstName name, bool hasParameterList = false, + const AstArray& parameters = {}); void visit(AstVisitor* visitor) override; bool hasPrefix; + bool hasParameterList; AstName prefix; AstName name; - AstArray generics; + AstArray parameters; }; struct AstTableProp @@ -1152,6 +1167,18 @@ class AstTypePack : public AstNode } }; +class AstTypePackExplicit : public AstTypePack +{ +public: + LUAU_RTTI(AstTypePackExplicit) + + AstTypePackExplicit(const Location& location, AstTypeList typeList); + + void visit(AstVisitor* visitor) override; + + AstTypeList typeList; +}; + class AstTypePackVariadic : public AstTypePack { public: diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index 02924e883..a7b2515a6 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -136,7 +136,10 @@ class DenseHashTable const Key& key = ItemInterface::getKey(data[i]); if (!eq(key, empty_key)) - *newtable.insert_unsafe(key) = data[i]; + { + Item* item = newtable.insert_unsafe(key); + *item = std::move(data[i]); + } } LUAU_ASSERT(count == newtable.count); diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index e6ebd503c..42c64dc92 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -218,13 +218,14 @@ class Parser AstTableIndexer* parseTableIndexerAnnotation(); - AstType* parseFunctionTypeAnnotation(); + AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); - AstType* parseSimpleTypeAnnotation(); + AstTypeOrPack parseSimpleTypeAnnotation(bool allowPack); + AstTypeOrPack parseTypeOrPackAnnotation(); AstType* parseTypeAnnotation(TempVector& parts, const Location& begin); AstType* parseTypeAnnotation(); @@ -284,7 +285,7 @@ class Parser std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); // `<' typeAnnotation[, ...] `>' - AstArray parseTypeParams(); + AstArray parseTypeParams(); AstExpr* parseString(); @@ -413,6 +414,7 @@ class Parser std::vector scratchLocal; std::vector scratchTableTypeProps; std::vector scratchAnnotation; + std::vector scratchTypeOrPackAnnotation; std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h new file mode 100644 index 000000000..641dfd3c3 --- /dev/null +++ b/Ast/include/Luau/TimeTrace.h @@ -0,0 +1,223 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Common.h" + +#include + +#include + +LUAU_FASTFLAG(DebugLuauTimeTracing) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +uint32_t getClockMicroseconds(); + +struct Token +{ + const char* name; + const char* category; +}; + +enum class EventType : uint8_t +{ + Enter, + Leave, + + ArgName, + ArgValue, +}; + +struct Event +{ + EventType type; + uint16_t token; + + union + { + uint32_t microsec; // 1 hour trace limit + uint32_t dataPos; + } data; +}; + +struct GlobalContext; +struct ThreadContext; + +GlobalContext& getGlobalContext(); + +uint16_t createToken(GlobalContext& context, const char* name, const char* category); +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext); +void releaseThread(GlobalContext& context, ThreadContext* threadContext); +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data); + +struct ThreadContext +{ + ThreadContext() + : globalContext(getGlobalContext()) + { + threadId = createThread(globalContext, this); + } + + ~ThreadContext() + { + if (!events.empty()) + flushEvents(); + + releaseThread(globalContext, this); + } + + void flushEvents() + { + static uint16_t flushToken = createToken(globalContext, "flushEvents", "TimeTrace"); + + events.push_back({EventType::Enter, flushToken, {getClockMicroseconds()}}); + + TimeTrace::flushEvents(globalContext, threadId, events, data); + + events.clear(); + data.clear(); + + events.push_back({EventType::Leave, 0, {getClockMicroseconds()}}); + } + + void eventEnter(uint16_t token) + { + eventEnter(token, getClockMicroseconds()); + } + + void eventEnter(uint16_t token, uint32_t microsec) + { + events.push_back({EventType::Enter, token, {microsec}}); + } + + void eventLeave() + { + eventLeave(getClockMicroseconds()); + } + + void eventLeave(uint32_t microsec) + { + events.push_back({EventType::Leave, 0, {microsec}}); + + if (events.size() > kEventFlushLimit) + flushEvents(); + } + + void eventArgument(const char* name, const char* value) + { + uint32_t pos = uint32_t(data.size()); + data.insert(data.end(), name, name + strlen(name) + 1); + events.push_back({EventType::ArgName, 0, {pos}}); + + pos = uint32_t(data.size()); + data.insert(data.end(), value, value + strlen(value) + 1); + events.push_back({EventType::ArgValue, 0, {pos}}); + } + + GlobalContext& globalContext; + uint32_t threadId; + std::vector events; + std::vector data; + + static constexpr size_t kEventFlushLimit = 8192; +}; + +ThreadContext& getThreadContext(); + +struct Scope +{ + explicit Scope(ThreadContext& context, uint16_t token) + : context(context) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventEnter(token); + } + + ~Scope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + context.eventLeave(); + } + + ThreadContext& context; +}; + +struct OptionalTailScope +{ + explicit OptionalTailScope(ThreadContext& context, uint16_t token, uint32_t threshold) + : context(context) + , token(token) + , threshold(threshold) + { + if (!FFlag::DebugLuauTimeTracing) + return; + + pos = uint32_t(context.events.size()); + microsec = getClockMicroseconds(); + } + + ~OptionalTailScope() + { + if (!FFlag::DebugLuauTimeTracing) + return; + + if (pos == context.events.size()) + { + uint32_t curr = getClockMicroseconds(); + + if (curr - microsec > threshold) + { + context.eventEnter(token, microsec); + context.eventLeave(curr); + } + } + } + + ThreadContext& context; + uint16_t token; + uint32_t threshold; + uint32_t microsec; + uint32_t pos; +}; + +LUAU_NOINLINE std::pair createScopeData(const char* name, const char* category); + +} // namespace TimeTrace +} // namespace Luau + +// Regular scope +#define LUAU_TIMETRACE_SCOPE(name, category) \ + static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) + +// A scope without nested scopes that may be skipped if the time it took is less than the threshold +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ + static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ + Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) + +// Extra key/value data can be added to regular scopes +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ + lttScopeStatic.second.eventArgument(name, value); \ + } while (false) + +#else + +#define LUAU_TIMETRACE_SCOPE(name, category) +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + } while (false) + +#endif diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index fff1537dd..b1209faa1 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -641,10 +641,12 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) + , genericPacks(genericPacks) , type(type) , exported(exported) { @@ -729,12 +731,14 @@ void AstStatError::visit(AstVisitor* visitor) } } -AstTypeReference::AstTypeReference(const Location& location, std::optional prefix, AstName name, const AstArray& generics) +AstTypeReference::AstTypeReference( + const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) , hasPrefix(bool(prefix)) + , hasParameterList(hasParameterList) , prefix(prefix ? *prefix : AstName()) , name(name) - , generics(generics) + , parameters(parameters) { } @@ -742,8 +746,13 @@ void AstTypeReference::visit(AstVisitor* visitor) { if (visitor->visit(this)) { - for (AstType* generic : generics) - generic->visit(visitor); + for (const AstTypeOrPack& param : parameters) + { + if (param.type) + param.type->visit(visitor); + else + param.typePack->visit(visitor); + } } } @@ -849,6 +858,24 @@ void AstTypeError::visit(AstVisitor* visitor) } } +AstTypePackExplicit::AstTypePackExplicit(const Location& location, AstTypeList typeList) + : AstTypePack(ClassIndex(), location) + , typeList(typeList) +{ +} + +void AstTypePackExplicit::visit(AstVisitor* visitor) +{ + if (visitor->visit(this)) + { + for (AstType* type : typeList.types) + type->visit(visitor); + + if (typeList.tailType) + typeList.tailType->visit(visitor); + } +} + AstTypePackVariadic::AstTypePackVariadic(const Location& location, AstType* variadicType) : AstTypePack(ClassIndex(), location) , variadicType(variadicType) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 6672efe8d..40026d8be 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/TimeTrace.h" + #include // Warning: If you are introducing new syntax, ensure that it is behind a separate @@ -13,6 +15,8 @@ LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) namespace Luau { @@ -148,6 +152,8 @@ static bool shouldParseTypePackAnnotation(Lexer& lexer) ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options) { + LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); + Parser p(buffer, bufferSize, names, allocator); try @@ -769,14 +775,14 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - // TODO: support generic type pack parameters in type aliases CLI-39907 auto [generics, genericPacks] = parseGenericTypeList(); expectAndConsume('=', "type alias"); AstType* type = parseTypeAnnotation(); - return allocator.alloc(Location(start, type->location), name->name, generics, type, exported); + return allocator.alloc( + Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray{}, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -1333,7 +1339,7 @@ AstType* Parser::parseTableTypeAnnotation() // ReturnType ::= TypeAnnotation | `(' TypeList `)' // FunctionTypeAnnotation ::= [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseFunctionTypeAnnotation() +AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1364,14 +1370,23 @@ AstType* Parser::parseFunctionTypeAnnotation() matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; - // Not a function at all. Just a parenthesized type. + AstArray paramTypes = copy(params); + + // Not a function at all. Just a parenthesized type. Or maybe a type pack with a single element if (params.size() == 1 && !varargAnnotation && monomorphic && lexer.current().type != Lexeme::SkinnyArrow) - return params[0]; + { + if (allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, nullptr})}; + else + return {params[0], {}}; + } + + if (lexer.current().type != Lexeme::SkinnyArrow && monomorphic && allowPack) + return {{}, allocator.alloc(begin.location, AstTypeList{paramTypes, varargAnnotation})}; - AstArray paramTypes = copy(params); AstArray> paramNames = copy(names); - return parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation); + return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, @@ -1421,7 +1436,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isUnion = true; } else if (c == '?') @@ -1434,7 +1449,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); isIntersection = true; } else @@ -1462,6 +1477,30 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location ParseError::raise(begin, "Composite type was not an intersection or union."); } +AstTypeOrPack Parser::parseTypeOrPackAnnotation() +{ + unsigned int oldRecursionCount = recursionCounter; + incrementRecursionCounter("type annotation"); + + Location begin = lexer.current().location; + + TempVector parts(scratchAnnotation); + + auto [type, typePack] = parseSimpleTypeAnnotation(true); + + if (typePack) + { + LUAU_ASSERT(!type); + return {{}, typePack}; + } + + parts.push_back(type); + + recursionCounter = oldRecursionCount; + + return {parseTypeAnnotation(parts, begin), {}}; +} + AstType* Parser::parseTypeAnnotation() { unsigned int oldRecursionCount = recursionCounter; @@ -1470,7 +1509,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation()); + parts.push_back(parseSimpleTypeAnnotation(false).type); recursionCounter = oldRecursionCount; @@ -1479,7 +1518,7 @@ AstType* Parser::parseTypeAnnotation() // typeannotation ::= nil | Name[`.' Name] [ `<' typeannotation [`,' ...] `>' ] | `typeof' `(' expr `)' | `{' [PropList] `}' // | [`<' varlist `>'] `(' [TypeList] `)' `->` ReturnType -AstType* Parser::parseSimpleTypeAnnotation() +AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); @@ -1488,7 +1527,7 @@ AstType* Parser::parseSimpleTypeAnnotation() if (lexer.current().type == Lexeme::ReservedNil) { nextLexeme(); - return allocator.alloc(begin, std::nullopt, nameNil); + return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } else if (lexer.current().type == Lexeme::Name) { @@ -1514,22 +1553,41 @@ AstType* Parser::parseSimpleTypeAnnotation() expectMatchAndConsume(')', typeofBegin); - return allocator.alloc(Location(begin, end), expr); + return {allocator.alloc(Location(begin, end), expr), {}}; } - AstArray generics = parseTypeParams(); + if (FFlag::LuauParseTypePackTypeParameters) + { + bool hasParameters = false; + AstArray parameters{}; + + if (lexer.current().type == '<') + { + hasParameters = true; + parameters = parseTypeParams(); + } + + Location end = lexer.previousLocation(); + + return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; + } + else + { + AstArray generics = parseTypeParams(); - Location end = lexer.previousLocation(); + Location end = lexer.previousLocation(); - return allocator.alloc(Location(begin, end), prefix, name.name, generics); + // false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks + return {allocator.alloc(Location(begin, end), prefix, name.name, false, generics), {}}; + } } else if (lexer.current().type == '{') { - return parseTableTypeAnnotation(); + return {parseTableTypeAnnotation(), {}}; } else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) { - return parseFunctionTypeAnnotation(); + return parseFunctionTypeAnnotation(allowPack); } else { @@ -1538,7 +1596,7 @@ AstType* Parser::parseSimpleTypeAnnotation() // For a missing type annoation, capture 'space' between last token and the next one location = Location(lexer.previousLocation().end, lexer.current().location.begin); - return reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; } } @@ -2312,18 +2370,59 @@ std::pair, AstArray> Parser::parseGenericTypeList() return {generics, genericPacks}; } -AstArray Parser::parseTypeParams() +AstArray Parser::parseTypeParams() { - TempVector result{scratchAnnotation}; + TempVector parameters{scratchTypeOrPackAnnotation}; if (lexer.current().type == '<') { Lexeme begin = lexer.current(); nextLexeme(); + bool seenPack = false; while (true) { - result.push_back(parseTypeAnnotation()); + if (FFlag::LuauParseTypePackTypeParameters) + { + if (shouldParseTypePackAnnotation(lexer)) + { + seenPack = true; + + auto typePack = parseTypePackAnnotation(); + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (typePack) + { + seenPack = true; + + if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them + parameters.push_back({{}, typePack}); + } + else + { + parameters.push_back({type, {}}); + } + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + } + else + { + parameters.push_back({parseTypeAnnotation(), {}}); + } + if (lexer.current().type == ',') nextLexeme(); else @@ -2333,7 +2432,7 @@ AstArray Parser::parseTypeParams() expectMatchAndConsume('>', begin); } - return copy(result); + return copy(parameters); } AstExpr* Parser::parseString() diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp new file mode 100644 index 000000000..e6aab20e5 --- /dev/null +++ b/Ast/src/TimeTrace.cpp @@ -0,0 +1,248 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TimeTrace.h" + +#include "Luau/StringUtils.h" + +#include +#include + +#include + +#ifdef _WIN32 +#include +#endif + +#ifdef __APPLE__ +#include +#include +#endif + +#include + +LUAU_FASTFLAGVARIABLE(DebugLuauTimeTracing, false) + +#if defined(LUAU_ENABLE_TIME_TRACE) + +namespace Luau +{ +namespace TimeTrace +{ +static double getClockPeriod() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceFrequency(&result); + return 1.0 / double(result.QuadPart); +#elif defined(__APPLE__) + mach_timebase_info_data_t result = {}; + mach_timebase_info(&result); + return double(result.numer) / double(result.denom) * 1e-9; +#elif defined(__linux__) + return 1e-9; +#else + return 1.0 / double(CLOCKS_PER_SEC); +#endif +} + +static double getClockTimestamp() +{ +#if defined(_WIN32) + LARGE_INTEGER result = {}; + QueryPerformanceCounter(&result); + return double(result.QuadPart); +#elif defined(__APPLE__) + return double(mach_absolute_time()); +#elif defined(__linux__) + timespec now; + clock_gettime(CLOCK_MONOTONIC, &now); + return now.tv_sec * 1e9 + now.tv_nsec; +#else + return double(clock()); +#endif +} + +uint32_t getClockMicroseconds() +{ + static double period = getClockPeriod() * 1e6; + static double start = getClockTimestamp(); + + return uint32_t((getClockTimestamp() - start) * period); +} + +struct GlobalContext +{ + GlobalContext() = default; + ~GlobalContext() + { + // Ideally we would want all ThreadContext destructors to run + // But in VS, not all thread_local object instances are destroyed + for (ThreadContext* context : threads) + context->flushEvents(); + + if (traceFile) + fclose(traceFile); + } + + std::mutex mutex; + std::vector threads; + uint32_t nextThreadId = 0; + std::vector tokens; + FILE* traceFile = nullptr; +}; + +GlobalContext& getGlobalContext() +{ + static GlobalContext context; + return context; +} + +uint16_t createToken(GlobalContext& context, const char* name, const char* category) +{ + std::scoped_lock lock(context.mutex); + + LUAU_ASSERT(context.tokens.size() < 64 * 1024); + + context.tokens.push_back({name, category}); + return uint16_t(context.tokens.size() - 1); +} + +uint32_t createThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + context.threads.push_back(threadContext); + + return ++context.nextThreadId; +} + +void releaseThread(GlobalContext& context, ThreadContext* threadContext) +{ + std::scoped_lock lock(context.mutex); + + if (auto it = std::find(context.threads.begin(), context.threads.end(), threadContext); it != context.threads.end()) + context.threads.erase(it); +} + +void flushEvents(GlobalContext& context, uint32_t threadId, const std::vector& events, const std::vector& data) +{ + std::scoped_lock lock(context.mutex); + + if (!context.traceFile) + { + context.traceFile = fopen("trace.json", "w"); + + if (!context.traceFile) + return; + + fprintf(context.traceFile, "[\n"); + } + + std::string temp; + const unsigned tempReserve = 64 * 1024; + temp.reserve(tempReserve); + + const char* rawData = data.data(); + + // Formatting state + bool unfinishedEnter = false; + bool unfinishedArgs = false; + + for (const Event& ev : events) + { + switch (ev.type) + { + case EventType::Enter: + { + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + Token& token = context.tokens[ev.token]; + + formatAppend(temp, R"({"name": "%s", "cat": "%s", "ph": "B", "ts": %u, "pid": 0, "tid": %u)", token.name, token.category, + ev.data.microsec, threadId); + unfinishedEnter = true; + } + break; + case EventType::Leave: + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + formatAppend(temp, + R"({"ph": "E", "ts": %u, "pid": 0, "tid": %u},)" + "\n", + ev.data.microsec, threadId); + break; + case EventType::ArgName: + LUAU_ASSERT(unfinishedEnter); + + if (!unfinishedArgs) + { + formatAppend(temp, R"(, "args": { "%s": )", rawData + ev.data.dataPos); + unfinishedArgs = true; + } + else + { + formatAppend(temp, R"(, "%s": )", rawData + ev.data.dataPos); + } + break; + case EventType::ArgValue: + LUAU_ASSERT(unfinishedArgs); + formatAppend(temp, R"("%s")", rawData + ev.data.dataPos); + break; + } + + // Don't want to hit the string capacity and reallocate + if (temp.size() > tempReserve - 1024) + { + fwrite(temp.data(), 1, temp.size(), context.traceFile); + temp.clear(); + } + } + + if (unfinishedArgs) + { + formatAppend(temp, "}"); + unfinishedArgs = false; + } + if (unfinishedEnter) + { + formatAppend(temp, "},\n"); + unfinishedEnter = false; + } + + fwrite(temp.data(), 1, temp.size(), context.traceFile); + fflush(context.traceFile); +} + +ThreadContext& getThreadContext() +{ + thread_local ThreadContext context; + return context; +} + +std::pair createScopeData(const char* name, const char* category) +{ + uint16_t token = createToken(Luau::TimeTrace::getGlobalContext(), name, category); + return {token, Luau::TimeTrace::getThreadContext()}; +} +} // namespace TimeTrace +} // namespace Luau + +#endif diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 920502b82..ed0552d74 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -111,11 +111,24 @@ struct CliFileResolver : Luau::FileResolver return Luau::SourceCode{*source, Luau::SourceCode::Module}; } + std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override + { + if (Luau::AstExprConstantString* expr = node->as()) + { + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua"; + + return {{name}}; + } + + return std::nullopt; + } + bool moduleExists(const Luau::ModuleName& name) const override { return !!readFile(name); } + std::optional fromAstFragment(Luau::AstExpr* expr) const override { return std::nullopt; @@ -130,11 +143,6 @@ struct CliFileResolver : Luau::FileResolver { return std::nullopt; } - - std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override - { - return std::nullopt; - } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 022eccb70..797ee20dc 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -4,6 +4,7 @@ #include "Luau/Parser.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/TimeTrace.h" #include #include @@ -137,6 +138,11 @@ struct Compiler uint32_t compileFunction(AstExprFunction* func) { + LUAU_TIMETRACE_SCOPE("Compiler::compileFunction", "Compiler"); + + if (func->debugname.value) + LUAU_TIMETRACE_ARGUMENT("name", func->debugname.value); + LUAU_ASSERT(!functions.contains(func)); LUAU_ASSERT(regTop == 0 && stackSize == 0 && localStack.empty() && upvals.empty()); @@ -3686,6 +3692,8 @@ struct Compiler void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstNameTable& names, const CompileOptions& options) { + LUAU_TIMETRACE_SCOPE("compileOrThrow", "Compiler"); + Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block table imports @@ -3748,6 +3756,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, const std::string& source, const std::string compile(const std::string& source, const CompileOptions& options, const ParseOptions& parseOptions, BytecodeEncoder* encoder) { + LUAU_TIMETRACE_SCOPE("compile", "Compiler"); + Allocator allocator; AstNameTable names(allocator); ParseResult result = Parser::parse(source.c_str(), source.size(), names, allocator, parseOptions); diff --git a/Sources.cmake b/Sources.cmake index 6f96f6aba..83ed52301 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -9,6 +9,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/ParseOptions.h Ast/include/Luau/Parser.h Ast/include/Luau/StringUtils.h + Ast/include/Luau/TimeTrace.h Ast/src/Ast.cpp Ast/src/Confusables.cpp @@ -16,6 +17,7 @@ target_sources(Luau.Ast PRIVATE Ast/src/Location.cpp Ast/src/Parser.cpp Ast/src/StringUtils.cpp + Ast/src/TimeTrace.cpp ) # Luau.Compiler Sources @@ -46,6 +48,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Predicate.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h + Analysis/include/Luau/Scope.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h Analysis/include/Luau/TopoSortStatements.h @@ -75,6 +78,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Module.cpp Analysis/src/Predicate.cpp Analysis/src/RequireTracer.cpp + Analysis/src/Scope.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp Analysis/src/TopoSortStatements.cpp @@ -188,6 +192,7 @@ if(TARGET Luau.UnitTest) tests/TopoSort.test.cpp tests/ToString.test.cpp tests/Transpiler.test.cpp + tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp tests/TypeInfer.builtins.test.cpp tests/TypeInfer.classes.test.cpp diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index ee4962a50..39a615978 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -8,9 +8,13 @@ #include "lmem.h" #include "lvm.h" +#if LUA_USE_LONGJMP +#include +#include +#else #include +#endif -#include #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) @@ -51,8 +55,8 @@ l_noret luaD_throw(lua_State* L, int errcode) longjmp(jb->buf, 1); } - if (L->global->panic) - L->global->panic(L, errcode); + if (L->global->cb.panic) + L->global->cb.panic(L, errcode); abort(); } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 9b040fb50..510a9f548 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) +LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) + LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -810,6 +812,133 @@ static size_t singlestep(lua_State* L) return cost; } +static size_t gcstep(lua_State* L, size_t limit) +{ + size_t cost = 0; + global_State* g = L->global; + switch (g->gcstate) + { + case GCSpause: + { + markroot(L); /* start a new collection */ + break; + } + case GCSpropagate: + { + if (FFlag::LuauRescanGrayAgain) + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) + { + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; + + g->gcstate = GCSpropagateagain; + } + } + else + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + } + break; + } + case GCSpropagateagain: + { + while (g->gray && cost < limit) + { + g->gcstats.currcycle.markitems++; + + cost += propagatemark(g); + } + + if (!g->gray) /* no more `gray' objects */ + { + double starttimestamp = lua_clock(); + + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + atomic(L); /* finish mark phase */ + + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } + break; + } + case GCSsweepstring: + { + while (g->sweepstrgc < g->strt.size && cost < limit) + { + size_t traversedcount = 0; + sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPCOST; + } + + // nothing more to sweep? + if (g->sweepstrgc >= g->strt.size) + { + // sweep string buffer list and preserve used string count + uint32_t nuse = L->global->strt.nuse; + + size_t traversedcount = 0; + sweepwholelist(L, &g->strbufgc, &traversedcount); + + L->global->strt.nuse = nuse; + + g->gcstats.currcycle.sweepitems += traversedcount; + g->gcstate = GCSsweep; // end sweep-string phase + } + break; + } + case GCSsweep: + { + while (*g->sweepgc && cost < limit) + { + size_t traversedcount = 0; + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPMAX * GC_SWEEPCOST; + } + + if (*g->sweepgc == NULL) + { /* nothing more to sweep? */ + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } + break; + } + default: + LUAU_ASSERT(0); + } + return cost; +} + static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats) { // adjust for error using Proportional-Integral controller @@ -878,33 +1007,40 @@ void luaC_step(lua_State* L, bool assist) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (assist) - g->gcstats.currcycle.assistwork += lim; - else - g->gcstats.currcycle.explicitwork += lim; - int lastgcstate = g->gcstate; double lastttimestamp = lua_clock(); - // always perform at least one single step - do + if (FFlag::LuauConsolidatedStep) { - lim -= singlestep(L); + size_t work = gcstep(L, lim); - // if we have switched to a different state, capture the duration of last stage - // this way we reduce the number of timer calls we make - if (lastgcstate != g->gcstate) + if (assist) + g->gcstats.currcycle.assistwork += work; + else + g->gcstats.currcycle.explicitwork += work; + } + else + { + // always perform at least one single step + do { - GC_INTERRUPT(lastgcstate); + lim -= singlestep(L); + + // if we have switched to a different state, capture the duration of last stage + // this way we reduce the number of timer calls we make + if (lastgcstate != g->gcstate) + { + GC_INTERRUPT(lastgcstate); - double now = lua_clock(); + double now = lua_clock(); - recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); - lastttimestamp = now; - lastgcstate = g->gcstate; - } - } while (lim > 0 && g->gcstate != GCSpause); + lastttimestamp = now; + lastgcstate = g->gcstate; + } + } while (lim > 0 && g->gcstate != GCSpause); + } recordGcStateTime(g, lastgcstate, lua_clock() - lastttimestamp, assist); @@ -931,7 +1067,14 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold -= debt; } - GC_INTERRUPT(g->gcstate); + if (FFlag::LuauConsolidatedStep) + { + GC_INTERRUPT(lastgcstate); + } + else + { + GC_INTERRUPT(g->gcstate); + } } void luaC_fullgc(lua_State* L) @@ -957,7 +1100,10 @@ void luaC_fullgc(lua_State* L) while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } finishGcCycleStats(g); @@ -968,7 +1114,10 @@ void luaC_fullgc(lua_State* L) markroot(L); while (g->gcstate != GCSpause) { - singlestep(L); + if (FFlag::LuauConsolidatedStep) + gcstep(L, SIZE_MAX); + else + singlestep(L); } /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ shrinkbuffersfull(L); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 090e183ff..de5788eb6 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -9,14 +9,8 @@ #include "ldebug.h" #include "lvm.h" -LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauTableMoveTelemetry, false) - LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) -bool lua_telemetry_table_move_oob_src_from = false; -bool lua_telemetry_table_move_oob_src_to = false; -bool lua_telemetry_table_move_oob_dst = false; - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -202,22 +196,6 @@ static int tmove(lua_State* L) int tt = !lua_isnoneornil(L, 5) ? 5 : 1; /* destination table */ luaL_checktype(L, tt, LUA_TTABLE); - if (DFFlag::LuauTableMoveTelemetry) - { - int nf = lua_objlen(L, 1); - int nt = lua_objlen(L, tt); - - // source index range must be in bounds in source table unless the table is empty (permits 1..#t moves) - if (!(f == 1 || (f >= 1 && f <= nf))) - lua_telemetry_table_move_oob_src_from = true; - if (!(e == nf || (e >= 1 && e <= nf))) - lua_telemetry_table_move_oob_src_to = true; - - // destination index must be in bounds in dest table or be exactly at the first empty element (permits concats) - if (!(t == nt + 1 || (t >= 1 && t <= nt + 1))) - lua_telemetry_table_move_oob_dst = true; - } - if (e >= f) { /* otherwise, nothing to move */ luaL_argcheck(L, f > 0 || e < INT_MAX + f, 3, "too many elements to move"); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 5f0ee922a..eed2862b1 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -16,8 +16,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauLoopUseSafeenv, false) - // Disable c99-designator to avoid the warning in CGOTO dispatch table #ifdef __clang__ #if __has_warning("-Wc99-designator") @@ -292,10 +290,6 @@ inline bool luau_skipstep(uint8_t op) return op == LOP_PREPVARARGS || op == LOP_BREAK; } -// declared in lbaselib.cpp, needed to support cases when pairs/ipairs have been replaced via setfenv -LUAI_FUNC int luaB_inext(lua_State* L); -LUAI_FUNC int luaB_next(lua_State* L); - template static void luau_execute(lua_State* L) { @@ -2223,8 +2217,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: ipairs/inext - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_inext; - if (safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnumber(ra + 2) && nvalue(ra + 2) == 0.0) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } @@ -2304,8 +2297,7 @@ static void luau_execute(lua_State* L) StkId ra = VM_REG(LUAU_INSN_A(insn)); // fast-path: pairs/next - bool safeenv = FFlag::LuauLoopUseSafeenv ? cl->env->safeenv : ttisfunction(ra) && clvalue(ra)->isC && clvalue(ra)->c.f == luaB_next; - if (safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) + if (cl->env->safeenv && ttistable(ra + 1) && ttisnil(ra + 2)) { setpvalue(ra + 2, reinterpret_cast(uintptr_t(0))); } diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 0a2323425..b932a85b3 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -12,7 +12,32 @@ #include -#include +// TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens +template +struct TempBuffer +{ + lua_State* L; + T* data; + size_t count; + + TempBuffer(lua_State* L, size_t count) + : L(L) + , data(luaM_newarray(L, count, T, 0)) + , count(count) + { + } + + ~TempBuffer() + { + luaM_freearray(L, data, count, T, 0); + } + + T& operator[](size_t index) + { + LUAU_ASSERT(index < count); + return data[index]; + } +}; void luaV_getimport(lua_State* L, Table* env, TValue* k, uint32_t id, bool propagatenil) { @@ -67,7 +92,7 @@ static unsigned int readVarInt(const char* data, size_t size, size_t& offset) return result; } -static TString* readString(std::vector& strings, const char* data, size_t size, size_t& offset) +static TString* readString(TempBuffer& strings, const char* data, size_t size, size_t& offset) { unsigned int id = readVarInt(data, size, offset); @@ -133,6 +158,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size } // pause GC for the duration of deserialization - some objects we're creating aren't rooted + // TODO: if an allocation error happens mid-load, we do not unpause GC! size_t GCthreshold = L->global->GCthreshold; L->global->GCthreshold = SIZE_MAX; @@ -144,7 +170,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // string table unsigned int stringCount = readVarInt(data, size, offset); - std::vector strings(stringCount); + TempBuffer strings(L, stringCount); for (unsigned int i = 0; i < stringCount; ++i) { @@ -156,7 +182,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size // proto table unsigned int protoCount = readVarInt(data, size, offset); - std::vector protos(protoCount); + TempBuffer protos(L, protoCount); for (unsigned int i = 0; i < protoCount; ++i) { diff --git a/bench/tests/deltablue.lua b/bench/tests/deltablue.lua deleted file mode 100644 index ecf246d30..000000000 --- a/bench/tests/deltablue.lua +++ /dev/null @@ -1,934 +0,0 @@ -local bench = script and require(script.Parent.bench_support) or require("bench_support") - --- Copyright 2008 the V8 project authors. All rights reserved. --- Copyright 1996 John Maloney and Mario Wolczko. - --- This program is free software; you can redistribute it and/or modify --- it under the terms of the GNU General Public License as published by --- the Free Software Foundation; either version 2 of the License, or --- (at your option) any later version. --- --- This program is distributed in the hope that it will be useful, --- but WITHOUT ANY WARRANTY; without even the implied warranty of --- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the --- GNU General Public License for more details. --- --- You should have received a copy of the GNU General Public License --- along with this program; if not, write to the Free Software --- Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA - - --- This implementation of the DeltaBlue benchmark is derived --- from the Smalltalk implementation by John Maloney and Mario --- Wolczko. Some parts have been translated directly, whereas --- others have been modified more aggresively to make it feel --- more like a JavaScript program. - - --- --- A JavaScript implementation of the DeltaBlue constraint-solving --- algorithm, as described in: --- --- "The DeltaBlue Algorithm: An Incremental Constraint Hierarchy Solver" --- Bjorn N. Freeman-Benson and John Maloney --- January 1990 Communications of the ACM, --- also available as University of Washington TR 89-08-06. --- --- Beware: this benchmark is written in a grotesque style where --- the constraint model is built by side-effects from constructors. --- I've kept it this way to avoid deviating too much from the original --- implementation. --- - -function class(base) - local T = {} - T.__index = T - - if base then - T.super = base - setmetatable(T, base) - end - - function T.new(...) - local O = {} - setmetatable(O, T) - O:constructor(...) - return O - end - - return T -end - -local planner - ---- O b j e c t M o d e l --- - -local function alert (...) print(...) end - -local OrderedCollection = class() - -function OrderedCollection:constructor() - self.elms = {} -end - -function OrderedCollection:add(elm) - self.elms[#self.elms + 1] = elm -end - -function OrderedCollection:at (index) - return self.elms[index] -end - -function OrderedCollection:size () - return #self.elms -end - -function OrderedCollection:removeFirst () - local e = self.elms[#self.elms] - self.elms[#self.elms] = nil - return e -end - -function OrderedCollection:remove (elm) - local index = 0 - local skipped = 0 - - for i = 1, #self.elms do - local value = self.elms[i] - if value ~= elm then - self.elms[index] = value - index = index + 1 - else - skipped = skipped + 1 - end - end - - local l = #self.elms - for i = 1, skipped do self.elms[l - i + 1] = nil end -end - --- --- S t r e n g t h --- - --- --- Strengths are used to measure the relative importance of constraints. --- New strengths may be inserted in the strength hierarchy without --- disrupting current constraints. Strengths cannot be created outside --- this class, so pointer comparison can be used for value comparison. --- - -local Strength = class() - -function Strength:constructor(strengthValue, name) - self.strengthValue = strengthValue - self.name = name -end - -function Strength.stronger (s1, s2) - return s1.strengthValue < s2.strengthValue -end - -function Strength.weaker (s1, s2) - return s1.strengthValue > s2.strengthValue -end - -function Strength.weakestOf (s1, s2) - return Strength.weaker(s1, s2) and s1 or s2 -end - -function Strength.strongest (s1, s2) - return Strength.stronger(s1, s2) and s1 or s2 -end - -function Strength:nextWeaker () - local v = self.strengthValue - if v == 0 then return Strength.WEAKEST - elseif v == 1 then return Strength.WEAK_DEFAULT - elseif v == 2 then return Strength.NORMAL - elseif v == 3 then return Strength.STRONG_DEFAULT - elseif v == 4 then return Strength.PREFERRED - elseif v == 5 then return Strength.REQUIRED - end -end - --- Strength constants. -Strength.REQUIRED = Strength.new(0, "required"); -Strength.STONG_PREFERRED = Strength.new(1, "strongPreferred"); -Strength.PREFERRED = Strength.new(2, "preferred"); -Strength.STRONG_DEFAULT = Strength.new(3, "strongDefault"); -Strength.NORMAL = Strength.new(4, "normal"); -Strength.WEAK_DEFAULT = Strength.new(5, "weakDefault"); -Strength.WEAKEST = Strength.new(6, "weakest"); - --- --- C o n s t r a i n t --- - --- --- An abstract class representing a system-maintainable relationship --- (or "constraint") between a set of variables. A constraint supplies --- a strength instance variable; concrete subclasses provide a means --- of storing the constrained variables and other information required --- to represent a constraint. --- - -local Constraint = class () - -function Constraint:constructor(strength) - self.strength = strength -end - --- --- Activate this constraint and attempt to satisfy it. --- -function Constraint:addConstraint () - self:addToGraph() - planner:incrementalAdd(self) -end - --- --- Attempt to find a way to enforce this constraint. If successful, --- record the solution, perhaps modifying the current dataflow --- graph. Answer the constraint that this constraint overrides, if --- there is one, or nil, if there isn't. --- Assume: I am not already satisfied. --- -function Constraint:satisfy (mark) - self:chooseMethod(mark) - if not self:isSatisfied() then - if self.strength == Strength.REQUIRED then - alert("Could not satisfy a required constraint!") - end - return nil - end - self:markInputs(mark) - local out = self:output() - local overridden = out.determinedBy - if overridden ~= nil then overridden:markUnsatisfied() end - out.determinedBy = self - if not planner:addPropagate(self, mark) then alert("Cycle encountered") end - out.mark = mark - return overridden -end - -function Constraint:destroyConstraint () - if self:isSatisfied() - then planner:incrementalRemove(self) - else self:removeFromGraph() - end -end - --- --- Normal constraints are not input constraints. An input constraint --- is one that depends on external state, such as the mouse, the --- keybord, a clock, or some arbitraty piece of imperative code. --- -function Constraint:isInput () - return false -end - - --- --- U n a r y C o n s t r a i n t --- - --- --- Abstract superclass for constraints having a single possible output --- variable. --- - -local UnaryConstraint = class(Constraint) - -function UnaryConstraint:constructor (v, strength) - UnaryConstraint.super.constructor(self, strength) - self.myOutput = v - self.satisfied = false - self:addConstraint() -end - --- --- Adds this constraint to the constraint graph --- -function UnaryConstraint:addToGraph () - self.myOutput:addConstraint(self) - self.satisfied = false -end - --- --- Decides if this constraint can be satisfied and records that --- decision. --- -function UnaryConstraint:chooseMethod (mark) - self.satisfied = (self.myOutput.mark ~= mark) - and Strength.stronger(self.strength, self.myOutput.walkStrength); -end - --- --- Returns true if this constraint is satisfied in the current solution. --- -function UnaryConstraint:isSatisfied () - return self.satisfied; -end - -function UnaryConstraint:markInputs (mark) - -- has no inputs -end - --- --- Returns the current output variable. --- -function UnaryConstraint:output () - return self.myOutput -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function UnaryConstraint:recalculate () - self.myOutput.walkStrength = self.strength - self.myOutput.stay = not self:isInput() - if self.myOutput.stay then - self:execute() -- Stay optimization - end -end - --- --- Records that this constraint is unsatisfied --- -function UnaryConstraint:markUnsatisfied () - self.satisfied = false -end - -function UnaryConstraint:inputsKnown () - return true -end - -function UnaryConstraint:removeFromGraph () - if self.myOutput ~= nil then - self.myOutput:removeConstraint(self) - end - self.satisfied = false -end - --- --- S t a y C o n s t r a i n t --- - --- --- Variables that should, with some level of preference, stay the same. --- Planners may exploit the fact that instances, if satisfied, will not --- change their output during plan execution. This is called "stay --- optimization". --- - -local StayConstraint = class(UnaryConstraint) - -function StayConstraint:constructor(v, str) - StayConstraint.super.constructor(self, v, str) -end - -function StayConstraint:execute () - -- Stay constraints do nothing -end - --- --- E d i t C o n s t r a i n t --- - --- --- A unary input constraint used to mark a variable that the client --- wishes to change. --- - -local EditConstraint = class (UnaryConstraint) - -function EditConstraint:constructor(v, str) - EditConstraint.super.constructor(self, v, str) -end - --- --- Edits indicate that a variable is to be changed by imperative code. --- -function EditConstraint:isInput () - return true -end - -function EditConstraint:execute () - -- Edit constraints do nothing -end - --- --- B i n a r y C o n s t r a i n t --- - -local Direction = {} -Direction.NONE = 0 -Direction.FORWARD = 1 -Direction.BACKWARD = -1 - --- --- Abstract superclass for constraints having two possible output --- variables. --- - -local BinaryConstraint = class(Constraint) - -function BinaryConstraint:constructor(var1, var2, strength) - BinaryConstraint.super.constructor(self, strength); - self.v1 = var1 - self.v2 = var2 - self.direction = Direction.NONE - self:addConstraint() -end - - --- --- Decides if this constraint can be satisfied and which way it --- should flow based on the relative strength of the variables related, --- and record that decision. --- -function BinaryConstraint:chooseMethod (mark) - if self.v1.mark == mark then - self.direction = (self.v2.mark ~= mark and Strength.stronger(self.strength, self.v2.walkStrength)) and Direction.FORWARD or Direction.NONE - end - if self.v2.mark == mark then - self.direction = (self.v1.mark ~= mark and Strength.stronger(self.strength, self.v1.walkStrength)) and Direction.BACKWARD or Direction.NONE - end - if Strength.weaker(self.v1.walkStrength, self.v2.walkStrength) then - self.direction = Strength.stronger(self.strength, self.v1.walkStrength) and Direction.BACKWARD or Direction.NONE - else - self.direction = Strength.stronger(self.strength, self.v2.walkStrength) and Direction.FORWARD or Direction.BACKWARD - end -end - --- --- Add this constraint to the constraint graph --- -function BinaryConstraint:addToGraph () - self.v1:addConstraint(self) - self.v2:addConstraint(self) - self.direction = Direction.NONE -end - --- --- Answer true if this constraint is satisfied in the current solution. --- -function BinaryConstraint:isSatisfied () - return self.direction ~= Direction.NONE -end - --- --- Mark the input variable with the given mark. --- -function BinaryConstraint:markInputs (mark) - self:input().mark = mark -end - --- --- Returns the current input variable --- -function BinaryConstraint:input () - return (self.direction == Direction.FORWARD) and self.v1 or self.v2 -end - --- --- Returns the current output variable --- -function BinaryConstraint:output () - return (self.direction == Direction.FORWARD) and self.v2 or self.v1 -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this --- constraint. Assume this constraint is satisfied. --- -function BinaryConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength); - out.stay = ihn.stay - if out.stay then self:execute() end -end - --- --- Record the fact that self constraint is unsatisfied. --- -function BinaryConstraint:markUnsatisfied () - self.direction = Direction.NONE -end - -function BinaryConstraint:inputsKnown (mark) - local i = self:input() - return i.mark == mark or i.stay or i.determinedBy == nil -end - -function BinaryConstraint:removeFromGraph () - if (self.v1 ~= nil) then self.v1:removeConstraint(self) end - if (self.v2 ~= nil) then self.v2:removeConstraint(self) end - self.direction = Direction.NONE -end - --- --- S c a l e C o n s t r a i n t --- - --- --- Relates two variables by the linear scaling relationship: "v2 = --- (v1 * scale) + offset". Either v1 or v2 may be changed to maintain --- this relationship but the scale factor and offset are considered --- read-only. --- - -local ScaleConstraint = class (BinaryConstraint) - -function ScaleConstraint:constructor(src, scale, offset, dest, strength) - self.direction = Direction.NONE - self.scale = scale - self.offset = offset - ScaleConstraint.super.constructor(self, src, dest, strength) -end - - --- --- Adds this constraint to the constraint graph. --- -function ScaleConstraint:addToGraph () - ScaleConstraint.super.addToGraph(self) - self.scale:addConstraint(self) - self.offset:addConstraint(self) -end - -function ScaleConstraint:removeFromGraph () - ScaleConstraint.super.removeFromGraph(self) - if (self.scale ~= nil) then self.scale:removeConstraint(self) end - if (self.offset ~= nil) then self.offset:removeConstraint(self) end -end - -function ScaleConstraint:markInputs (mark) - ScaleConstraint.super.markInputs(self, mark); - self.offset.mark = mark - self.scale.mark = mark -end - --- --- Enforce this constraint. Assume that it is satisfied. --- -function ScaleConstraint:execute () - if self.direction == Direction.FORWARD then - self.v2.value = self.v1.value * self.scale.value + self.offset.value - else - self.v1.value = (self.v2.value - self.offset.value) / self.scale.value - end -end - --- --- Calculate the walkabout strength, the stay flag, and, if it is --- 'stay', the value for the current output of this constraint. Assume --- this constraint is satisfied. --- -function ScaleConstraint:recalculate () - local ihn = self:input() - local out = self:output() - out.walkStrength = Strength.weakestOf(self.strength, ihn.walkStrength) - out.stay = ihn.stay and self.scale.stay and self.offset.stay - if out.stay then self:execute() end -end - --- --- E q u a l i t y C o n s t r a i n t --- - --- --- Constrains two variables to have the same value. --- - -local EqualityConstraint = class (BinaryConstraint) - -function EqualityConstraint:constructor(var1, var2, strength) - EqualityConstraint.super.constructor(self, var1, var2, strength) -end - - --- --- Enforce this constraint. Assume that it is satisfied. --- -function EqualityConstraint:execute () - self:output().value = self:input().value -end - --- --- V a r i a b l e --- - --- --- A constrained variable. In addition to its value, it maintain the --- structure of the constraint graph, the current dataflow graph, and --- various parameters of interest to the DeltaBlue incremental --- constraint solver. --- -local Variable = class () - -function Variable:constructor(name, initialValue) - self.value = initialValue or 0 - self.constraints = OrderedCollection.new() - self.determinedBy = nil - self.mark = 0 - self.walkStrength = Strength.WEAKEST - self.stay = true - self.name = name -end - --- --- Add the given constraint to the set of all constraints that refer --- this variable. --- -function Variable:addConstraint (c) - self.constraints:add(c) -end - --- --- Removes all traces of c from this variable. --- -function Variable:removeConstraint (c) - self.constraints:remove(c) - if self.determinedBy == c then - self.determinedBy = nil - end -end - --- --- P l a n n e r --- - --- --- The DeltaBlue planner --- -local Planner = class() -function Planner:constructor() - self.currentMark = 0 -end - --- --- Attempt to satisfy the given constraint and, if successful, --- incrementally update the dataflow graph. Details: If satifying --- the constraint is successful, it may override a weaker constraint --- on its output. The algorithm attempts to resatisfy that --- constraint using some other method. This process is repeated --- until either a) it reaches a variable that was not previously --- determined by any constraint or b) it reaches a constraint that --- is too weak to be satisfied using any of its methods. The --- variables of constraints that have been processed are marked with --- a unique mark value so that we know where we've been. This allows --- the algorithm to avoid getting into an infinite loop even if the --- constraint graph has an inadvertent cycle. --- -function Planner:incrementalAdd (c) - local mark = self:newMark() - local overridden = c:satisfy(mark) - while overridden ~= nil do - overridden = overridden:satisfy(mark) - end -end - --- --- Entry point for retracting a constraint. Remove the given --- constraint and incrementally update the dataflow graph. --- Details: Retracting the given constraint may allow some currently --- unsatisfiable downstream constraint to be satisfied. We therefore collect --- a list of unsatisfied downstream constraints and attempt to --- satisfy each one in turn. This list is traversed by constraint --- strength, strongest first, as a heuristic for avoiding --- unnecessarily adding and then overriding weak constraints. --- Assume: c is satisfied. --- -function Planner:incrementalRemove (c) - local out = c:output() - c:markUnsatisfied() - c:removeFromGraph() - local unsatisfied = self:removePropagateFrom(out) - local strength = Strength.REQUIRED - repeat - for i = 1, unsatisfied:size() do - local u = unsatisfied:at(i) - if u.strength == strength then - self:incrementalAdd(u) - end - end - strength = strength:nextWeaker() - until strength == Strength.WEAKEST -end - --- --- Select a previously unused mark value. --- -function Planner:newMark () - self.currentMark = self.currentMark + 1 - return self.currentMark -end - --- --- Extract a plan for resatisfaction starting from the given source --- constraints, usually a set of input constraints. This method --- assumes that stay optimization is desired; the plan will contain --- only constraints whose output variables are not stay. Constraints --- that do no computation, such as stay and edit constraints, are --- not included in the plan. --- Details: The outputs of a constraint are marked when it is added --- to the plan under construction. A constraint may be appended to --- the plan when all its input variables are known. A variable is --- known if either a) the variable is marked (indicating that has --- been computed by a constraint appearing earlier in the plan), b) --- the variable is 'stay' (i.e. it is a constant at plan execution --- time), or c) the variable is not determined by any --- constraint. The last provision is for past states of history --- variables, which are not stay but which are also not computed by --- any constraint. --- Assume: sources are all satisfied. --- -local Plan -- FORWARD DECLARATION -function Planner:makePlan (sources) - local mark = self:newMark() - local plan = Plan.new() - local todo = sources - while todo:size() > 0 do - local c = todo:removeFirst() - if c:output().mark ~= mark and c:inputsKnown(mark) then - plan:addConstraint(c) - c:output().mark = mark - self:addConstraintsConsumingTo(c:output(), todo) - end - end - return plan -end - --- --- Extract a plan for resatisfying starting from the output of the --- given constraints, usually a set of input constraints. --- -function Planner:extractPlanFromConstraints (constraints) - local sources = OrderedCollection.new() - for i = 1, constraints:size() do - local c = constraints:at(i) - if c:isInput() and c:isSatisfied() then - -- not in plan already and eligible for inclusion - sources:add(c) - end - end - return self:makePlan(sources) -end - --- --- Recompute the walkabout strengths and stay flags of all variables --- downstream of the given constraint and recompute the actual --- values of all variables whose stay flag is true. If a cycle is --- detected, remove the given constraint and answer --- false. Otherwise, answer true. --- Details: Cycles are detected when a marked variable is --- encountered downstream of the given constraint. The sender is --- assumed to have marked the inputs of the given constraint with --- the given mark. Thus, encountering a marked node downstream of --- the output constraint means that there is a path from the --- constraint's output to one of its inputs. --- -function Planner:addPropagate (c, mark) - local todo = OrderedCollection.new() - todo:add(c) - while todo:size() > 0 do - local d = todo:removeFirst() - if d:output().mark == mark then - self:incrementalRemove(c) - return false - end - d:recalculate() - self:addConstraintsConsumingTo(d:output(), todo) - end - return true -end - - --- --- Update the walkabout strengths and stay flags of all variables --- downstream of the given constraint. Answer a collection of --- unsatisfied constraints sorted in order of decreasing strength. --- -function Planner:removePropagateFrom (out) - out.determinedBy = nil - out.walkStrength = Strength.WEAKEST - out.stay = true - local unsatisfied = OrderedCollection.new() - local todo = OrderedCollection.new() - todo:add(out) - while todo:size() > 0 do - local v = todo:removeFirst() - for i = 1, v.constraints:size() do - local c = v.constraints:at(i) - if not c:isSatisfied() then unsatisfied:add(c) end - end - local determining = v.determinedBy - for i = 1, v.constraints:size() do - local next = v.constraints:at(i); - if next ~= determining and next:isSatisfied() then - next:recalculate() - todo:add(next:output()) - end - end - end - return unsatisfied -end - -function Planner:addConstraintsConsumingTo (v, coll) - local determining = v.determinedBy - local cc = v.constraints - for i = 1, cc:size() do - local c = cc:at(i) - if c ~= determining and c:isSatisfied() then - coll:add(c) - end - end -end - --- --- P l a n --- - --- --- A Plan is an ordered list of constraints to be executed in sequence --- to resatisfy all currently satisfiable constraints in the face of --- one or more changing inputs. --- -Plan = class() -function Plan:constructor() - self.v = OrderedCollection.new() -end - -function Plan:addConstraint (c) - self.v:add(c) -end - -function Plan:size () - return self.v:size() -end - -function Plan:constraintAt (index) - return self.v:at(index) -end - -function Plan:execute () - for i = 1, self:size() do - local c = self:constraintAt(i) - c:execute() - end -end - --- --- M a i n --- - --- --- This is the standard DeltaBlue benchmark. A long chain of equality --- constraints is constructed with a stay constraint on one end. An --- edit constraint is then added to the opposite end and the time is --- measured for adding and removing this constraint, and extracting --- and executing a constraint satisfaction plan. There are two cases. --- In case 1, the added constraint is stronger than the stay --- constraint and values must propagate down the entire length of the --- chain. In case 2, the added constraint is weaker than the stay --- constraint so it cannot be accomodated. The cost in this case is, --- of course, very low. Typical situations lie somewhere between these --- two extremes. --- -local function chainTest(n) - planner = Planner.new() - local prev = nil - local first = nil - local last = nil - - -- Build chain of n equality constraints - for i = 0, n do - local name = "v" .. i; - local v = Variable.new(name) - if prev ~= nil then EqualityConstraint.new(prev, v, Strength.REQUIRED) end - if i == 0 then first = v end - if i == n then last = v end - prev = v - end - - StayConstraint.new(last, Strength.STRONG_DEFAULT) - local edit = EditConstraint.new(first, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 0, 99 do - first.value = i - plan:execute() - if last.value ~= i then - alert("Chain test failed.") - end - end -end - -local function change(v, newValue) - local edit = EditConstraint.new(v, Strength.PREFERRED) - local edits = OrderedCollection.new() - edits:add(edit) - local plan = planner:extractPlanFromConstraints(edits) - for i = 1, 10 do - v.value = newValue - plan:execute() - end - edit:destroyConstraint() -end - --- --- This test constructs a two sets of variables related to each --- other by a simple linear transformation (scale and offset). The --- time is measured to change a variable on either side of the --- mapping and to change the scale and offset factors. --- -local function projectionTest(n) - planner = Planner.new(); - local scale = Variable.new("scale", 10); - local offset = Variable.new("offset", 1000); - local src = nil - local dst = nil; - - local dests = OrderedCollection.new(); - for i = 0, n - 1 do - src = Variable.new("src" .. i, i); - dst = Variable.new("dst" .. i, i); - dests:add(dst); - StayConstraint.new(src, Strength.NORMAL); - ScaleConstraint.new(src, scale, offset, dst, Strength.REQUIRED); - end - - change(src, 17) - if dst.value ~= 1170 then alert("Projection 1 failed") end - change(dst, 1050) - if src.value ~= 5 then alert("Projection 2 failed") end - change(scale, 5) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 1000 then - alert("Projection 3 failed") - end - end - change(offset, 2000) - for i = 0, n - 2 do - if dests:at(i + 1).value ~= i * 5 + 2000 then - alert("Projection 4 failed") - end - end -end - -function test() - local t0 = os.clock() - chainTest(1000); - projectionTest(1000); - local t1 = os.clock() - return t1-t0 -end - -bench.runCode(test, "deltablue") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 9cd642cfd..07910a0ac 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -23,19 +23,17 @@ static std::optional nullCallback(std::string tag, std::op return std::nullopt; } -struct ACFixture : Fixture +template +struct ACFixtureImpl : BaseType { AutocompleteResult autocomplete(unsigned row, unsigned column) { - return Luau::autocomplete(frontend, "MainModule", Position{row, column}, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", Position{row, column}, nullCallback); } AutocompleteResult autocomplete(char marker) { - auto i = markerPosition.find(marker); - LUAU_ASSERT(i != markerPosition.end()); - const Position& pos = i->second; - return Luau::autocomplete(frontend, "MainModule", pos, nullCallback); + return Luau::autocomplete(this->frontend, "MainModule", getPosition(marker), nullCallback); } CheckResult check(const std::string& source) @@ -45,16 +43,18 @@ struct ACFixture : Fixture filteredSource.reserve(source.size()); Position curPos(0, 0); + char prevChar{}; for (char c : source) { - if (c == '@' && !filteredSource.empty()) + if (prevChar == '@') + { + LUAU_ASSERT("Illegal marker character" && c >= '0' && c <= '9'); + LUAU_ASSERT("Duplicate marker found" && markerPosition.count(c) == 0); + markerPosition.insert(std::pair{c, curPos}); + } + else if (c == '@') { - char prevChar = filteredSource.back(); - filteredSource.pop_back(); - curPos.column--; // Adjust column position since we removed a character from the output - LUAU_ASSERT("Illegal marker character" && prevChar >= '0' && prevChar <= '9'); - LUAU_ASSERT("Duplicate marker found" && markerPosition.count(prevChar) == 0); - markerPosition.insert(std::pair{prevChar, curPos}); + // skip the '@' character } else { @@ -69,22 +69,39 @@ struct ACFixture : Fixture curPos.column++; } } + prevChar = c; } + LUAU_ASSERT("Digit expected after @ symbol" && prevChar != '@'); return Fixture::check(filteredSource); } + const Position& getPosition(char marker) const + { + auto i = markerPosition.find(marker); + LUAU_ASSERT(i != markerPosition.end()); + return i->second; + } + // Maps a marker character (0-9 inclusive) to a position in the source code. std::map markerPosition; }; +struct ACFixture : ACFixtureImpl +{ +}; + +struct UnfrozenACFixture : ACFixtureImpl +{ +}; + TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") { - check(" "); + check(" @1"); - auto ac = autocomplete(0, 1); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("table")); @@ -93,26 +110,26 @@ TEST_CASE_FIXTURE(ACFixture, "empty_program") TEST_CASE_FIXTURE(ACFixture, "local_initializer") { - check("local a = "); + check("local a = @1"); - auto ac = autocomplete(0, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); CHECK(ac.entryMap.count("math")); } TEST_CASE_FIXTURE(ACFixture, "leave_numbers_alone") { - check("local a = 3.1"); + check("local a = 3.@11"); - auto ac = autocomplete(0, 12); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "user_defined_globals") { - check("local myLocal = 4; "); + check("local myLocal = 4; @1"); - auto ac = autocomplete(0, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("table")); @@ -124,20 +141,20 @@ TEST_CASE_FIXTURE(ACFixture, "dont_suggest_local_before_its_definition") check(R"( local myLocal = 4 function abc() - local myInnerLocal = 1 - +@1 local myInnerLocal = 1 +@2 end - )"); +@3 )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); - ac = autocomplete(4, 0); + ac = autocomplete('2'); CHECK(ac.entryMap.count("myLocal")); CHECK(ac.entryMap.count("myInnerLocal")); - ac = autocomplete(6, 0); + ac = autocomplete('3'); CHECK(ac.entryMap.count("myLocal")); CHECK(!ac.entryMap.count("myInnerLocal")); } @@ -146,10 +163,10 @@ TEST_CASE_FIXTURE(ACFixture, "recursive_function") { check(R"( function foo() - end +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("foo")); } @@ -158,11 +175,11 @@ TEST_CASE_FIXTURE(ACFixture, "nested_recursive_function") check(R"( local function outer() local function inner() - end +@1 end end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("inner")); CHECK(ac.entryMap.count("outer")); } @@ -171,11 +188,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") { check(R"( local function abc() - +@1 end )"); - auto ac = autocomplete(2, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); CHECK(ac.entryMap.count("table")); @@ -183,11 +200,11 @@ TEST_CASE_FIXTURE(ACFixture, "user_defined_local_functions_in_own_definition") check(R"( local abc = function() - +@1 end )"); - ac = autocomplete(2, 0); + ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); // FIXME: This is actually incorrect! CHECK(ac.entryMap.count("table")); @@ -202,9 +219,9 @@ TEST_CASE_FIXTURE(ACFixture, "global_functions_are_not_scoped_lexically") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.empty()); CHECK(ac.entryMap.count("abc")); @@ -220,9 +237,9 @@ TEST_CASE_FIXTURE(ACFixture, "local_functions_fall_out_of_scope") end end - )"); +@1 )"); - auto ac = autocomplete(6, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(!ac.entryMap.count("abc")); @@ -233,10 +250,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") check(R"( function abc(test) - end +@1 end )"); - auto ac = autocomplete(3, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("test")); } @@ -244,11 +261,10 @@ TEST_CASE_FIXTURE(ACFixture, "function_parameters") TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = table. -- Line 1 - -- | Column 23 + local a = table.@1 )"); - auto ac = autocomplete(1, 24); + auto ac = autocomplete('1'); CHECK_EQ(16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); @@ -260,10 +276,10 @@ TEST_CASE_FIXTURE(ACFixture, "nested_member_completions") { check(R"( local tbl = { abc = { def = 1234, egh = false } } - tbl.abc. + tbl.abc. @1 )"); - auto ac = autocomplete(2, 17); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("def")); CHECK(ac.entryMap.count("egh")); @@ -274,10 +290,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table") check(R"( local tbl = {} tbl.prop = 5 - tbl. + tbl.@1 )"); - auto ac = autocomplete(3, 12); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -288,10 +304,10 @@ TEST_CASE_FIXTURE(ACFixture, "unsealed_table_2") local tbl = {} local inner = { prop = 5 } tbl.inner = inner - tbl.inner. + tbl.inner. @1 )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("prop")); } @@ -302,10 +318,10 @@ TEST_CASE_FIXTURE(ACFixture, "cyclic_table") local abc = {} local def = { abc = abc } abc.def = def - abc.def. + abc.def. @1 )"); - auto ac = autocomplete(4, 17); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("abc")); } @@ -315,11 +331,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_union") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 | t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("b2")); } @@ -330,11 +346,11 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") type t1 = { a1 : string, b2 : number } type t2 = { b2 : string, c3 : string } function func(abc : t1 & t2) - abc. + abc. @1 end )"); - auto ac = autocomplete(4, 18); + auto ac = autocomplete('1'); CHECK_EQ(3, ac.entryMap.size()); CHECK(ac.entryMap.count("a1")); CHECK(ac.entryMap.count("b2")); @@ -344,20 +360,19 @@ TEST_CASE_FIXTURE(ACFixture, "table_intersection") TEST_CASE_FIXTURE(ACFixture, "get_string_completions") { check(R"( - local a = ("foo"): -- Line 1 - -- | Column 26 + local a = ("foo"):@1 )"); - auto ac = autocomplete(1, 26); + auto ac = autocomplete('1'); CHECK_EQ(17, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") { - check(""); + check("@1"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -366,12 +381,12 @@ TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_new_statement") TEST_CASE_FIXTURE(ACFixture, "get_suggestions_for_the_very_start_of_the_script") { - check(R"( + check(R"(@1 function aaa() end )"); - auto ac = autocomplete(0, 0); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } @@ -382,11 +397,11 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") local game = { GetService=function(s) return 'hello' end } function a() - game: + game: @1 end )"); - auto ac = autocomplete(4, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -396,10 +411,10 @@ TEST_CASE_FIXTURE(ACFixture, "method_call_inside_function_body") TEST_CASE_FIXTURE(ACFixture, "method_call_inside_if_conditional") { check(R"( - if table: + if table: @1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); CHECK(ac.entryMap.count("concat")); @@ -411,12 +426,12 @@ TEST_CASE_FIXTURE(ACFixture, "statement_between_two_statements") check(R"( function getmyscripts() end - g + g@1 getmyscripts() )"); - auto ac = autocomplete(3, 9); + auto ac = autocomplete('1'); CHECK_NE(0, ac.entryMap.size()); @@ -431,11 +446,11 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") function B() local A = {two=2} - A + A @1 end )"); - auto ac = autocomplete(6, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("A")); @@ -448,12 +463,12 @@ TEST_CASE_FIXTURE(ACFixture, "bias_toward_inner_scope") TEST_CASE_FIXTURE(ACFixture, "recommend_statement_starting_keywords") { - check(""); - auto ac = autocomplete(0, 0); + check("@1"); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("local")); - check("local i = "); - auto ac2 = autocomplete(0, 10); + check("local i = @1"); + auto ac2 = autocomplete('1'); CHECK(!ac2.entryMap.count("local")); } @@ -464,9 +479,9 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_overwrite_context_sensitive_kws") end - )"); +@1 )"); - auto ac = autocomplete(5, 0); + auto ac = autocomplete('1'); AutocompleteEntry entry = ac.entryMap["continue"]; CHECK(entry.kind == AutocompleteEntryKind::Binding); @@ -480,11 +495,11 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") function foo:bar() end --[[ - foo: + foo:@1 ]] )"); - auto ac = autocomplete(6, 16); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -492,10 +507,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_comment") TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comment") { check(R"( - --!strict + --!strict@1 )"); - auto ac = autocomplete(1, 17); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -505,10 +520,10 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; check(R"( - --[[ + --[[ @1 )"); - auto ac = autocomplete(1, 13); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -517,129 +532,129 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co { ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[["); + check("--[[@1"); - auto ac = autocomplete(0, 4); + auto ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_middle_keywords") { check(R"( - for x = + for x @1= )"); - auto ac1 = autocomplete(1, 14); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - for x = 1 + for x =@1 1 )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("do"), 0); CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - for x = 1, 2 + for x = 1,@1 2 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("do"), 1); check(R"( - for x = 1, 2, + for x = 1, @12, )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("do"), 0); CHECK_EQ(ac4.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 + for x = 1, 2, @15 )"); - auto ac5 = autocomplete(1, 22); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("do"), 1); CHECK_EQ(ac5.entryMap.count("end"), 0); check(R"( - for x = 1, 2, 5 f + for x = 1, 2, 5 f@1 )"); - auto ac6 = autocomplete(1, 25); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.size(), 1); CHECK_EQ(ac6.entryMap.count("do"), 1); check(R"( - for x = 1, 2, 5 do + for x = 1, 2, 5 do @1 )"); - auto ac7 = autocomplete(1, 32); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") { check(R"( - for + for @1 )"); - auto ac1 = autocomplete(1, 12); + auto ac1 = autocomplete('1'); CHECK_EQ(0, ac1.entryMap.size()); check(R"( - for x + for x@1 @2 )"); - auto ac2 = autocomplete(1, 13); + auto ac2 = autocomplete('1'); CHECK_EQ(0, ac2.entryMap.size()); - auto ac2a = autocomplete(1, 14); + auto ac2a = autocomplete('2'); CHECK_EQ(1, ac2a.entryMap.size()); CHECK_EQ(1, ac2a.entryMap.count("in")); check(R"( - for x in y + for x in y@1 )"); - auto ac3 = autocomplete(1, 18); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("table"), 1); CHECK_EQ(ac3.entryMap.count("do"), 0); check(R"( - for x in y + for x in y @1 )"); - auto ac4 = autocomplete(1, 19); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.size(), 1); CHECK_EQ(ac4.entryMap.count("do"), 1); check(R"( - for x in f f + for x in f f@1 )"); - auto ac5 = autocomplete(1, 20); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.size(), 1); CHECK_EQ(ac5.entryMap.count("do"), 1); check(R"( - for x in y do + for x in y do @1 )"); - auto ac6 = autocomplete(1, 23); + auto ac6 = autocomplete('1'); CHECK_EQ(ac6.entryMap.count("in"), 0); CHECK_EQ(ac6.entryMap.count("table"), 1); CHECK_EQ(ac6.entryMap.count("end"), 1); CHECK_EQ(ac6.entryMap.count("function"), 1); check(R"( - for x in y do e + for x in y do e@1 )"); - auto ac7 = autocomplete(1, 23); + auto ac7 = autocomplete('1'); CHECK_EQ(ac7.entryMap.count("in"), 0); CHECK_EQ(ac7.entryMap.count("table"), 1); CHECK_EQ(ac7.entryMap.count("end"), 1); @@ -649,33 +664,33 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_for_in_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") { check(R"( - while + while@1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 0); CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - while true + while true @1 )"); - auto ac2 = autocomplete(1, 19); + auto ac2 = autocomplete('1'); CHECK_EQ(1, ac2.entryMap.size()); CHECK_EQ(ac2.entryMap.count("do"), 1); check(R"( - while true do + while true do @1 )"); - auto ac3 = autocomplete(1, 23); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("end"), 1); check(R"( - while true d + while true d@1 )"); - auto ac4 = autocomplete(1, 20); + auto ac4 = autocomplete('1'); CHECK_EQ(1, ac4.entryMap.size()); CHECK_EQ(ac4.entryMap.count("do"), 1); } @@ -683,10 +698,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_while_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") { check(R"( - if + if @1 )"); - auto ac1 = autocomplete(1, 13); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("then"), 0); CHECK_EQ(ac1.entryMap.count("function"), 1); // FIXME: This is kind of dumb. It is technically syntactically valid but you can never do anything interesting with this. @@ -696,10 +711,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac1.entryMap.count("end"), 0); check(R"( - if x + if x @1 )"); - auto ac2 = autocomplete(1, 14); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("then"), 1); CHECK_EQ(ac2.entryMap.count("function"), 0); CHECK_EQ(ac2.entryMap.count("else"), 0); @@ -707,20 +722,20 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") CHECK_EQ(ac2.entryMap.count("end"), 0); check(R"( - if x t + if x t@1 )"); - auto ac3 = autocomplete(1, 14); + auto ac3 = autocomplete('1'); CHECK_EQ(1, ac3.entryMap.size()); CHECK_EQ(ac3.entryMap.count("then"), 1); check(R"( if x then - +@1 end )"); - auto ac4 = autocomplete(2, 0); + auto ac4 = autocomplete('1'); CHECK_EQ(ac4.entryMap.count("then"), 0); CHECK_EQ(ac4.entryMap.count("else"), 1); CHECK_EQ(ac4.entryMap.count("function"), 1); @@ -729,11 +744,11 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - t + t@1 end )"); - auto ac4a = autocomplete(2, 13); + auto ac4a = autocomplete('1'); CHECK_EQ(ac4a.entryMap.count("then"), 0); CHECK_EQ(ac4a.entryMap.count("table"), 1); CHECK_EQ(ac4a.entryMap.count("else"), 1); @@ -741,12 +756,12 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") check(R"( if x then - +@1 elseif x then end )"); - auto ac5 = autocomplete(2, 0); + auto ac5 = autocomplete('1'); CHECK_EQ(ac5.entryMap.count("then"), 0); CHECK_EQ(ac5.entryMap.count("function"), 1); CHECK_EQ(ac5.entryMap.count("else"), 0); @@ -757,10 +772,10 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_middle_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_in_repeat") { check(R"( - repeat + repeat @1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("until"), 1); } @@ -769,48 +784,48 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_until_expression") { check(R"( repeat - until + until @1 )"); - auto ac = autocomplete(2, 16); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_names") { check(R"( - local ab + local ab@1 )"); - auto ac1 = autocomplete(1, 16); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local ab, cd + local ab, cd@1 )"); - auto ac2 = autocomplete(1, 20); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_fn_exprs") { check(R"( - local function f() + local function f() @1 )"); - auto ac = autocomplete(1, 28); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_end_with_lambda") { check(R"( - local a = function() local bar = foo en + local a = function() local bar = foo en@1 )"); - auto ac = autocomplete(1, 47); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("end"), 1); } @@ -818,10 +833,10 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") { check(R"( repeat - for x + for x @1 )"); - auto ac1 = autocomplete(2, 18); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("in"), 1); CHECK_EQ(ac1.entryMap.count("until"), 0); } @@ -829,112 +844,112 @@ TEST_CASE_FIXTURE(ACFixture, "stop_at_first_stat_when_recommending_keywords") TEST_CASE_FIXTURE(ACFixture, "autocomplete_repeat_middle_keyword") { check(R"( - repeat + repeat @1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.count("do"), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); CHECK_EQ(ac1.entryMap.count("until"), 1); check(R"( - repeat f f + repeat f f@1 )"); - auto ac2 = autocomplete(1, 18); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("function"), 1); CHECK_EQ(ac2.entryMap.count("until"), 1); check(R"( repeat - u + u@1 until )"); - auto ac3 = autocomplete(2, 13); + auto ac3 = autocomplete('1'); CHECK_EQ(ac3.entryMap.count("until"), 0); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local f + local f@1 )"); - auto ac1 = autocomplete(1, 15); + auto ac1 = autocomplete('1'); CHECK_EQ(ac1.entryMap.size(), 1); CHECK_EQ(ac1.entryMap.count("function"), 1); check(R"( - local f, cd + local f@1, cd )"); - auto ac2 = autocomplete(1, 15); + auto ac2 = autocomplete('1'); CHECK(ac2.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function") { check(R"( - local function + local function @1 )"); - auto ac = autocomplete(1, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( - local function s + local function @1s@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 24); + ac = autocomplete('2'); CHECK(ac.entryMap.empty()); check(R"( - local function () + local function @1()@2 )"); - ac = autocomplete(1, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); - ac = autocomplete(1, 25); + ac = autocomplete('2'); CHECK(ac.entryMap.count("end")); check(R"( - local function something + local function something@1 )"); - ac = autocomplete(1, 32); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( local tbl = {} - function tbl.something() end + function tbl.something@1() end )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "local_function_params") { check(R"( - local function abc(def) + local function @1a@2bc(@3d@4ef)@5 @6 )"); - CHECK(autocomplete(1, 23).entryMap.empty()); - CHECK(autocomplete(1, 24).entryMap.empty()); - CHECK(autocomplete(1, 27).entryMap.empty()); - CHECK(autocomplete(1, 28).entryMap.empty()); - CHECK(!autocomplete(1, 31).entryMap.empty()); + CHECK(autocomplete('1').entryMap.empty()); + CHECK(autocomplete('2').entryMap.empty()); + CHECK(autocomplete('3').entryMap.empty()); + CHECK(autocomplete('4').entryMap.empty()); + CHECK(!autocomplete('5').entryMap.empty()); - CHECK(!autocomplete(1, 32).entryMap.empty()); + CHECK(!autocomplete('6').entryMap.empty()); check(R"( local function abc(def) - end +@1 end )"); for (unsigned int i = 23; i < 31; ++i) @@ -943,16 +958,16 @@ TEST_CASE_FIXTURE(ACFixture, "local_function_params") } CHECK(!autocomplete(1, 32).entryMap.empty()); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - local function abc(def, ghi) + local function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 35); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } @@ -981,48 +996,48 @@ TEST_CASE_FIXTURE(ACFixture, "global_function_params") check(R"( function abc(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("abc"), 1); CHECK_EQ(ac2.entryMap.count("def"), 1); check(R"( - function abc(def, ghi) + function abc(def, ghi@1) end )"); - auto ac3 = autocomplete(1, 29); + auto ac3 = autocomplete('1'); CHECK(ac3.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "arguments_to_global_lambda") { check(R"( - abc = function(def, ghi) + abc = function(def, ghi@1) end )"); - auto ac = autocomplete(1, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { check(R"( - abc = function(def) + abc = function(def) @1 )"); for (unsigned int i = 20; i < 27; ++i) { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( - abc = function(def) + abc = function(def) @1 end )"); @@ -1030,25 +1045,25 @@ TEST_CASE_FIXTURE(ACFixture, "function_expr_params") { CHECK(autocomplete(1, i).entryMap.empty()); } - CHECK(!autocomplete(1, 28).entryMap.empty()); + CHECK(!autocomplete('1').entryMap.empty()); check(R"( abc = function(def) - +@1 end )"); - auto ac2 = autocomplete(2, 0); + auto ac2 = autocomplete('1'); CHECK_EQ(ac2.entryMap.count("def"), 1); } TEST_CASE_FIXTURE(ACFixture, "local_initializer") { check(R"( - local a = t + local a = t@1 )"); - auto ac = autocomplete(1, 19); + auto ac = autocomplete('1'); CHECK_EQ(ac.entryMap.count("table"), 1); CHECK_EQ(ac.entryMap.count("true"), 1); } @@ -1056,20 +1071,20 @@ TEST_CASE_FIXTURE(ACFixture, "local_initializer") TEST_CASE_FIXTURE(ACFixture, "local_initializer_2") { check(R"( - local a= + local a=@1 )"); - auto ac = autocomplete(1, 16); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("table")); } TEST_CASE_FIXTURE(ACFixture, "get_member_completions") { check(R"( - local a = 12.3 + local a = 12.@13 )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -1083,21 +1098,21 @@ TEST_CASE_FIXTURE(ACFixture, "sometimes_the_metatable_is_an_error") return setmetatable({x=6}, X) -- oops! end local t = T.new() - t. + t. @1 )"); - autocomplete(8, 12); + autocomplete('1'); // Don't crash! } TEST_CASE_FIXTURE(ACFixture, "local_types_builtin") { check(R"( -local a: n +local a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1108,23 +1123,23 @@ TEST_CASE_FIXTURE(ACFixture, "private_types") check(R"( do type num = number - local a: nu - local b: num + local a: n@1u + local b: nu@2m end -local a: nu +local a: nu@3 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 15); + ac = autocomplete('2'); CHECK(ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(6, 11); + ac = autocomplete('3'); CHECK(!ac.entryMap.count("num")); CHECK(ac.entryMap.count("number")); @@ -1136,11 +1151,11 @@ TEST_CASE_FIXTURE(ACFixture, "type_scoping_easy") type Table = { a: number, b: number } do type Table = { x: string, y: string } - local a: T + local a: T@1 end )"); - auto ac = autocomplete(4, 14); + auto ac = autocomplete('1'); REQUIRE(ac.entryMap.count("Table")); REQUIRE(ac.entryMap["Table"].type); @@ -1198,11 +1213,11 @@ local a: aaa. TEST_CASE_FIXTURE(ACFixture, "argument_types") { check(R"( -local function f(a: n +local function f(a: n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1211,11 +1226,11 @@ local b: string = "don't trip" TEST_CASE_FIXTURE(ACFixture, "return_types") { check(R"( -local function f(a: number): n +local function f(a: number): n@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1225,10 +1240,10 @@ TEST_CASE_FIXTURE(ACFixture, "as_types") { check(R"( local a: any = 5 -local b: number = (a :: n +local b: number = (a :: n@1 )"); - auto ac = autocomplete(2, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1237,34 +1252,34 @@ local b: number = (a :: n TEST_CASE_FIXTURE(ACFixture, "function_type_types") { check(R"( -local a: (n -local b: (number, (n -local c: (number, (number) -> n -local d: (number, (number) -> (number, n -local e: (n: n +local a: (n@1 +local b: (number, (n@2 +local c: (number, (number) -> n@3 +local d: (number, (number) -> (number, n@4 +local e: (n: n@5 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(2, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(3, 31); + ac = autocomplete('3'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(4, 40); + ac = autocomplete('4'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); - ac = autocomplete(5, 14); + ac = autocomplete('5'); CHECK(ac.entryMap.count("nil")); CHECK(ac.entryMap.count("number")); @@ -1276,11 +1291,11 @@ TEST_CASE_FIXTURE(ACFixture, "generic_types") ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); check(R"( -function f(a: T +function f(a: T@1 local b: string = "don't trip" )"); - auto ac = autocomplete(1, 25); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("Tee")); } @@ -1293,10 +1308,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(o +return target(o@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1307,10 +1322,10 @@ local function target(a: number, b: string) return a + #b end local one = 4 local two = "hello" -return target(one, t +return target(one, t@1 )"); - ac = autocomplete(5, 20); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1321,10 +1336,10 @@ return target(one, t local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1334,10 +1349,10 @@ return target(a. local function target(a: number, b: string) return a + #b end local a = { one = 4, two = "hello" } -return target(a.one, a. +return target(a.one, a.@1 )"); - ac = autocomplete(4, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1348,10 +1363,10 @@ return target(a.one, a. local function target(a: string?) return #b end local a = { one = 4, two = "hello" } -return target(a. +return target(a.@1 )"); - ac = autocomplete(4, 16); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1363,10 +1378,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_suggestion_in_table") check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { a = a. +local b: Foo = { a = a.@1 )"); - auto ac = autocomplete(3, 23); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -1375,10 +1390,10 @@ local b: Foo = { a = a. check(R"( type Foo = { a: number, b: string } local a = { one = 4, two = "hello" } -local b: Foo = { b = a. +local b: Foo = { b = a.@1 )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("two")); CHECK(ac.entryMap["two"].typeCorrect == TypeCorrectKind::Correct); @@ -1392,10 +1407,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1406,10 +1421,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end -return target(bar1, b +return target(bar1, b@1 )"); - ac = autocomplete(5, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar2")); CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1420,10 +1435,10 @@ local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); @@ -1433,69 +1448,69 @@ return target(b TEST_CASE_FIXTURE(ACFixture, "type_correct_local_type_suggestion") { check(R"( -local b: s = "str" +local b: s@1 = "str" )"); - auto ac = autocomplete(1, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return "str" end -local b: s = f() +local b: s@1 = f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: s, c: n = "str", 2 +local b: s@1, c: n@2 = "str", 2 )"); - ac = autocomplete(1, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f() return 1, "str", 3 end -local a: b, b: n, c: s, d: n = false, f() +local a: b@1, b: n@2, c: s@3, d: n@4 = false, f() )"); - ac = autocomplete(2, 10); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 16); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 22); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(2, 28); + ac = autocomplete('4'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( local function f(): ...number return 1, 2, 3 end -local a: boolean, b: n = false, f() +local a: boolean, b: n@1 = false, f() )"); - ac = autocomplete(2, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1504,46 +1519,46 @@ local a: boolean, b: n = false, f() TEST_CASE_FIXTURE(ACFixture, "type_correct_function_type_suggestion") { check(R"( -local b: (n) -> number = function(a: number, b: string) return a + #b end +local b: (n@1) -> number = function(a: number, b: string) return a + #b end )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, s = function(a: number, b: string) return a + #b end +local b: (number, s@1 = function(a: number, b: string) return a + #b end )"); - ac = autocomplete(1, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, string) -> b = function(a: number, b: string): boolean return a + #b == 0 end +local b: (number, string) -> b@1 = function(a: number, b: string): boolean return a + #b == 0 end )"); - ac = autocomplete(1, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("boolean")); CHECK(ac.entryMap["boolean"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number, ...s) = function(a: number, ...: string) return a end +local b: (number, ...s@1) = function(a: number, ...: string) return a end )"); - ac = autocomplete(1, 22); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" end +local b: (number) -> ...s@1 = function(a: number): ...string return "a", "b", "c" end )"); - ac = autocomplete(1, 25); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1552,24 +1567,24 @@ local b: (number) -> ...s = function(a: number): ...string return "a", "b", "c" TEST_CASE_FIXTURE(ACFixture, "type_correct_full_type_suggestion") { check(R"( -local b: = "str" +local b:@1 @2= "str" )"); - auto ac = autocomplete(1, 8); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(1, 9); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); check(R"( -local b: = function(a: number) return -a end +local b: @1= function(a: number) return -a end )"); - ac = autocomplete(1, 9); + ac = autocomplete('1'); CHECK(ac.entryMap.count("(number) -> number")); CHECK(ac.entryMap["(number) -> number"].typeCorrect == TypeCorrectKind::Correct); @@ -1580,12 +1595,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: n, b) +local function d(a: n@1, b) return target(a, b) end )"); - auto ac = autocomplete(3, 21); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1593,12 +1608,12 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: s) +local function d(a, b: s@1) return target(a, b) end )"); - ac = autocomplete(3, 24); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1606,17 +1621,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a: , b) +local function d(a:@1 @2, b) return target(a, b) end )"); - ac = autocomplete(3, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 20); + ac = autocomplete('2'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1624,17 +1639,17 @@ end check(R"( local function target(a: number, b: string) return a + #b end -local function d(a, b: ): number +local function d(a, b: @1)@2: number return target(a, b) end )"); - ac = autocomplete(3, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 24); + ac = autocomplete('2'); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::None); } @@ -1644,10 +1659,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion") check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1655,10 +1670,10 @@ local x = target(function(a: check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n +local x = target(function(a: n@1 )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1666,17 +1681,17 @@ local x = target(function(a: n check(R"( local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: n, b: ) +local x = target(function(a: n@1, b: @2) return a + #b end) )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(3, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1684,12 +1699,12 @@ end) check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a: n) +local x = target(function(a: n@1) return a end )"); - ac = autocomplete(3, 30); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1700,12 +1715,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(...:n) +local x = target(function(...:n@1) return a end )"); - auto ac = autocomplete(3, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1713,12 +1728,12 @@ end check(R"( local function target(callback: (...number) -> number) return callback(1, 2, 3) end -local x = target(function(a:number, b:number, ...:) +local x = target(function(a:number, b:number, ...:@1) return a + b end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1729,12 +1744,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") check(R"( local function target(callback: () -> number) return callback() end -local x = target(function(): n +local x = target(function(): n@1 return 1 end )"); - auto ac = autocomplete(3, 30); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1742,12 +1757,12 @@ end check(R"( local function target(callback: () -> (number, number)) return callback() end -local x = target(function(): (number, n +local x = target(function(): (number, n@1 return 1, 2 end )"); - ac = autocomplete(3, 39); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1758,12 +1773,12 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): ...n +local x = target(function(): ...n@1 return 1, 2, 3 end )"); - auto ac = autocomplete(3, 33); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1771,12 +1786,12 @@ end check(R"( local function target(callback: () -> ...number) return callback() end -local x = target(function(): (number, number, ...n +local x = target(function(): (number, number, ...n@1 return 1, 2, 3 end )"); - ac = autocomplete(3, 50); + ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1787,10 +1802,10 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_suggestion_opt check(R"( local function target(callback: nil | (a: number, b: string) -> number) return callback(4, "hello") end -local x = target(function(a: +local x = target(function(a: @1 )"); - auto ac = autocomplete(3, 29); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); @@ -1803,21 +1818,21 @@ local t = {} t.x = 5 function t:target(callback: (a: number, b: string) -> number) return callback(self.x, "hello") end -local x = t:target(function(a: , b: ) end) -local y = t.target(t, function(a: number, b: ) end) +local x = t:target(function(a: @1, b:@2 ) end) +local y = t.target(t, function(a: number, b: @3) end) )"); - auto ac = autocomplete(5, 31); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("number")); CHECK(ac.entryMap["number"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(5, 35); + ac = autocomplete('2'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(6, 45); + ac = autocomplete('3'); CHECK(ac.entryMap.count("string")); CHECK(ac.entryMap["string"].typeCorrect == TypeCorrectKind::Correct); @@ -1899,26 +1914,26 @@ TEST_CASE_FIXTURE(ACFixture, "do_not_suggest_synthetic_table_name") { check(R"( local foo = { a = 1, b = 2 } -local bar: = foo +local bar: @1= foo )"); - auto ac = autocomplete(2, 11); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_function_no_parenthesis") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") { check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) reutrn a .. 'x' end -return target(b +return target(b@1 )"); - auto ac = autocomplete(frontend, "MainModule", Position{5, 15}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1930,16 +1945,16 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( local function f(a: { x: number, y: number }) return a.x + a.y end -local fp: = f +local fp: @1= f )"); - auto ac = autocomplete(2, 10); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenFixture here -TEST_CASE_FIXTURE(UnfrozenFixture, "type_correct_keywords") +// CLI-45692: Remove UnfrozenACFixture here +TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") { check(R"( local function a(x: boolean) end @@ -1951,33 +1966,33 @@ local function e(x: ((number) -> string) & ((boolean) -> number)) end local tru = {} local ni = false -local ac = a(t) -local bc = b(n) -local cc = c(f) -local dc = d(f) -local ec = e(f) +local ac = a(t@1) +local bc = b(n@2) +local cc = c(f@3) +local dc = d(f@4) +local ec = e(f@5) )"); - auto ac = autocomplete(frontend, "MainModule", Position{10, 14}, nullCallback); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{11, 14}, nullCallback); + ac = autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{12, 14}, nullCallback); + ac = autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{13, 14}, nullCallback); + ac = autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete(frontend, "MainModule", Position{14, 14}, nullCallback); + ac = autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -1988,10 +2003,10 @@ local target: ((number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - auto ac = autocomplete(5, 15); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2002,10 +2017,10 @@ local target: ((number) -> string) & ((number) -> number)) local one = 4 local two = "hello" -return target(o) +return target(o@1) )"); - ac = autocomplete(5, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2016,10 +2031,10 @@ local target: ((number, number) -> string) & ((string) -> number)) local one = 4 local two = "hello" -return target(1, o) +return target(1, o@1) )"); - ac = autocomplete(5, 18); + ac = autocomplete('1'); CHECK(ac.entryMap.count("one")); CHECK(ac.entryMap["one"].typeCorrect == TypeCorrectKind::Correct); @@ -2032,10 +2047,10 @@ TEST_CASE_FIXTURE(ACFixture, "optional_members") local a = { x = 2, y = 3 } type A = typeof(a) local b: A? = a -return b. +return b.@1 )"); - auto ac = autocomplete(4, 9); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2045,10 +2060,10 @@ return b. local a = { x = 2, y = 3 } type A = typeof(a) local b: nil | A = a -return b. +return b.@1 )"); - ac = autocomplete(4, 9); + ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("x")); @@ -2056,10 +2071,10 @@ return b. check(R"( local b: nil | nil -return b. +return b.@1 )"); - ac = autocomplete(2, 9); + ac = autocomplete('1'); CHECK_EQ(0, ac.entryMap.size()); } @@ -2067,26 +2082,26 @@ return b. TEST_CASE_FIXTURE(ACFixture, "no_function_name_suggestions") { check(R"( -function na +function na@1 )"); - auto ac = autocomplete(1, 11); + auto ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function +local function @1 )"); - ac = autocomplete(1, 15); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); check(R"( -local function na +local function na@1 )"); - ac = autocomplete(1, 17); + ac = autocomplete('1'); CHECK(ac.entryMap.empty()); } @@ -2095,20 +2110,20 @@ TEST_CASE_FIXTURE(ACFixture, "skip_current_local") { check(R"( local other = 1 -local name = na +local name = na@1 )"); - auto ac = autocomplete(2, 15); + auto ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(ac.entryMap.count("other")); check(R"( local other = 1 -local name, test = na +local name, test = na@1 )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("name")); CHECK(!ac.entryMap.count("test")); @@ -2119,26 +2134,26 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_members") { check(R"( local a = { done = 1, forever = 2 } -local b = a.do -local c = a.for -local d = a. +local b = a.do@1 +local c = a.for@2 +local d = a.@3 do end )"); - auto ac = autocomplete(2, 14); + auto ac = autocomplete('1'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(3, 15); + ac = autocomplete('2'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); CHECK(ac.entryMap.count("forever")); - ac = autocomplete(4, 12); + ac = autocomplete('3'); CHECK_EQ(2, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2150,10 +2165,10 @@ TEST_CASE_FIXTURE(ACFixture, "keyword_methods") check(R"( local a = {} function a:done() end -local b = a:do +local b = a:do@1 )"); - auto ac = autocomplete(3, 14); + auto ac = autocomplete('1'); CHECK_EQ(1, ac.entryMap.size()); CHECK(ac.entryMap.count("done")); @@ -2247,29 +2262,29 @@ local elsewhere = false local doover = false local endurance = true -if 1 then -else +if 1 then@1 +else@2 end -while false do +while false do@3 end -repeat +repeat@4 until )"); - auto ac = autocomplete(6, 9); + auto ac = autocomplete('1'); CHECK(ac.entryMap.size() == 1); CHECK(ac.entryMap.count("then")); - ac = autocomplete(7, 4); + ac = autocomplete('2'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); - ac = autocomplete(10, 14); + ac = autocomplete('3'); CHECK(ac.entryMap.count("do")); - ac = autocomplete(13, 6); + ac = autocomplete('4'); CHECK(ac.entryMap.count("do")); // FIXME: ideally we want to handle start and end of all statements as well @@ -2284,11 +2299,11 @@ local elsewhere = false if true then return 1 -el +el@1 end )"); - auto ac = autocomplete(5, 2); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere") == 0); @@ -2300,11 +2315,11 @@ if true then return 1 else return 2 -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); CHECK(ac.entryMap.count("elsewhere")); @@ -2316,10 +2331,10 @@ if true then print("1") elif true then print("2") -el +el@1 end )"); - ac = autocomplete(7, 2); + ac = autocomplete('1'); CHECK(ac.entryMap.count("else")); CHECK(ac.entryMap.count("elseif")); CHECK(ac.entryMap.count("elsewhere")); @@ -2360,30 +2375,30 @@ TEST_CASE_FIXTURE(ACFixture, "suggest_table_keys") { check(R"( type Test = { first: number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - auto ac = autocomplete(2, 19); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Intersection check(R"( type Test = { first: number } & { second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Union check(R"( type Test = { first: number, second: number } | { second: number, third: number } -local t: Test = { s } +local t: Test = { s@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("second")); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("third")); @@ -2391,60 +2406,60 @@ local t: Test = { s } // No parenthesis suggestion check(R"( type Test = { first: (number) -> number, second: number } -local t: Test = { f } +local t: Test = { f@1 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap["first"].parens == ParenthesesRecommendation::None); // When key is changed check(R"( type Test = { first: number, second: number } -local t: Test = { f = 2 } +local t: Test = { f@1 = 2 } )"); - ac = autocomplete(2, 19); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { ["f"] } +local t: Test = { ["f@1"] } )"); - ac = autocomplete(2, 21); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Not an alternative key syntax check(R"( type Test = { first: number, second: number } -local t: Test = { "f" } +local t: Test = { "f@1" } )"); - ac = autocomplete(2, 20); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(!ac.entryMap.count("second")); // Skip keys that are already defined check(R"( type Test = { first: number, second: number } -local t: Test = { first = 2, s } +local t: Test = { first = 2, s@1 } )"); - ac = autocomplete(2, 30); + ac = autocomplete('1'); CHECK(!ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); // Don't skip active key check(R"( type Test = { first: number, second: number } -local t: Test = { first } +local t: Test = { first@1 } )"); - ac = autocomplete(2, 23); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); @@ -2452,22 +2467,22 @@ local t: Test = { first } check(R"( local t = { { first = 5, second = 10 }, - { f } + { f@1 } } )"); - ac = autocomplete(3, 7); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); check(R"( local t = { [2] = { first = 5, second = 10 }, - [5] = { f } + [5] = { f@1 } } )"); - ac = autocomplete(3, 13); + ac = autocomplete('1'); CHECK(ac.entryMap.count("first")); CHECK(ac.entryMap.count("second")); } @@ -2502,15 +2517,15 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") local temp = false local even = true; local a = true -a = if t1@emp then t -a = if temp t2@ -a = if temp then e3@ -a = if temp then even e4@ -a = if temp then even elseif t5@ -a = if temp then even elseif true t6@ -a = if temp then even elseif true then t7@ -a = if temp then even elseif true then temp e8@ -a = if temp then even elseif true then temp else e9@ +a = if t@1emp then t +a = if temp t@2 +a = if temp then e@3 +a = if temp then even e@4 +a = if temp then even elseif t@5 +a = if temp then even elseif true t@6 +a = if temp then even elseif true then t@7 +a = if temp then even elseif true then temp e@8 +a = if temp then even elseif true then temp else e@9 )"); auto ac = autocomplete('1'); @@ -2573,4 +2588,20 @@ a = if temp then even elseif true then temp else e9@ } } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") +{ + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + check(R"( +type A = () -> T... +local a: A<(number, s@1> + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 26bc77f70..29c33f7c1 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -32,6 +32,55 @@ std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const return std::nullopt; } +std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) +{ + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "game") + return ModuleInfo{"game"}; + if (g->name == "workspace") + return ModuleInfo{"workspace"}; + if (g->name == "script") + return context ? std::optional(*context) : std::nullopt; + } + else if (AstExprIndexName* i = expr->as(); i && context) + { + if (i->index == "Parent") + { + std::string_view view = context->name; + size_t lastSeparatorIndex = view.find_last_of('/'); + + if (lastSeparatorIndex == std::string_view::npos) + return std::nullopt; + + return ModuleInfo{ModuleName(view.substr(0, lastSeparatorIndex)), context->optional}; + } + else + { + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + } + else if (AstExprIndexExpr* i = expr->as(); i && context) + { + if (AstExprConstantString* index = i->index->as()) + { + return ModuleInfo{context->name + '/' + std::string(index->value.data, index->value.size), context->optional}; + } + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; +} + ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const { return lhs + "/" + ModuleName(rhs); diff --git a/tests/Fixture.h b/tests/Fixture.h index c6294b014..1480a7f6a 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -65,6 +65,8 @@ struct TestFileResolver } std::optional fromAstFragment(AstExpr* expr) const override; + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; std::optional getParentModuleName(const ModuleName& name) const override; diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 3f33a5d19..fbfec6367 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -58,6 +58,35 @@ struct NaiveFileResolver : NullFileResolver return std::nullopt; } + std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override + { + if (AstExprGlobal* g = expr->as()) + { + if (g->name == "Modules") + return ModuleInfo{"Modules"}; + + if (g->name == "game") + return ModuleInfo{"game"}; + } + else if (AstExprIndexName* i = expr->as()) + { + if (context) + return ModuleInfo{context->name + '/' + i->index.value, context->optional}; + } + else if (AstExprCall* call = expr->as(); call && call->self && call->args.size >= 1 && context) + { + if (AstExprConstantString* index = call->args.data[0]->as()) + { + AstName func = call->func->as()->index; + + if (func == "GetService" && context->name == "game") + return ModuleInfo{"game/" + std::string(index->value.data, index->value.size)}; + } + } + + return std::nullopt; + } + ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override { return lhs + "/" + ModuleName(rhs); @@ -528,7 +557,7 @@ TEST_CASE_FIXTURE(FrontendFixture, "ignore_require_to_nonexistent_file") { fileResolver.source["Modules/A"] = R"( local Modules = script - local B = require(Modules.B :: any) + local B = require(Modules.B) :: any )"; CheckResult result = frontend.check("Modules/A"); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index c8eff3991..a9ed139f1 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1400,6 +1400,8 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { + ScopedFastFlag sff("LuauLinterTableMoveZero", true); + LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1417,9 +1419,12 @@ table.remove(t, 0) table.remove(t, #t-1) table.insert(t, string.find("hello", "h")) + +table.move(t, 0, #t, 1, tt) +table.move(t, 1, #t, 0, tt) )"); - REQUIRE_EQ(result.warnings.size(), 6); + REQUIRE_EQ(result.warnings.size(), 8); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1429,6 +1434,8 @@ table.insert(t, string.find("hello", "h")) "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[5].text, "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); + CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 1b146ed2a..18f55d2c1 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Scope.h" #include "Fixture.h" diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index f3c76d55a..931a8403a 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index cb03a7bd9..a80718e47 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2519,4 +2519,19 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") } } +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index cbd4af29a..b9fd04d69 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -57,6 +57,7 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") { AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz + require(m) )"); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -70,22 +71,22 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_local") AstExprIndexName* value = loc->values.data[0]->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar/Baz", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo/Bar", result.exprs[value]); + CHECK_EQ("workspace/Foo/Bar", result.exprs[value].name); value = value->expr->as(); REQUIRE(value); REQUIRE(result.exprs.contains(value)); - CHECK_EQ("workspace/Foo", result.exprs[value]); + CHECK_EQ("workspace/Foo", result.exprs[value].name); AstExprGlobal* workspace = value->expr->as(); REQUIRE(workspace); REQUIRE(result.exprs.contains(workspace)); - CHECK_EQ("workspace", result.exprs[workspace]); + CHECK_EQ("workspace", result.exprs[workspace].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") @@ -93,9 +94,10 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") AstStatBlock* block = parse(R"( local m = workspace.Foo.Bar.Baz local n = m.Quux + require(n) )"); - REQUIRE_EQ(2, block->body.size); + REQUIRE_EQ(3, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); @@ -104,13 +106,13 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_transitive_local") REQUIRE_EQ(1, local->vars.size); REQUIRE(result.exprs.contains(local->values.data[0])); - CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]]); + CHECK_EQ("workspace/Foo/Bar/Baz/Quux", result.exprs[local->values.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") { AstStatBlock* block = parse(R"( - local M = require(workspace.Game.Thing, workspace.Something.Else) + local M = require(workspace.Game.Thing) )"); REQUIRE_EQ(1, block->body.size); @@ -124,52 +126,9 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "trace_function_arguments") AstExprCall* call = local->values.data[0]->as(); REQUIRE(call != nullptr); - REQUIRE_EQ(2, call->args.size); - - CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]]); - CHECK_EQ("workspace/Something/Else", result.exprs[call->args.data[1]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_GetService_calls") -{ - AstStatBlock* block = parse(R"( - local R = game:GetService('ReplicatedStorage').Roact - local Roact = require(R) - )"); - REQUIRE_EQ(2, block->body.size); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - AstStatLocal* local = block->body.data[0]->as(); - REQUIRE(local != nullptr); - - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[local->values.data[0]]); - - AstStatLocal* local2 = block->body.data[1]->as(); - REQUIRE(local2 != nullptr); - REQUIRE_EQ(1, local2->values.size); - - AstExprCall* call = local2->values.data[0]->as(); - REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("game/ReplicatedStorage/Roact", result.exprs[call->args.data[0]]); -} - -TEST_CASE_FIXTURE(RequireTracerFixture, "follow_WaitForChild_calls") -{ - ScopedFastFlag luauTraceRequireLookupChild("LuauTraceRequireLookupChild", true); - - AstStatBlock* block = parse(R"( -local A = require(workspace:WaitForChild('ReplicatedStorage').Content) -local B = require(workspace:FindFirstChild('ReplicatedFirst').Data) - )"); - - RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); - - REQUIRE_EQ(2, result.requires.size()); - CHECK_EQ("workspace/ReplicatedStorage/Content", result.requires[0].first); - CHECK_EQ("workspace/ReplicatedFirst/Data", result.requires[1].first); + CHECK_EQ("workspace/Game/Thing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") @@ -200,22 +159,23 @@ TEST_CASE_FIXTURE(RequireTracerFixture, "follow_typeof") REQUIRE(call != nullptr); REQUIRE_EQ(1, call->args.size); - CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]]); + CHECK_EQ("workspace/CoolThing", result.exprs[call->args.data[0]].name); } TEST_CASE_FIXTURE(RequireTracerFixture, "follow_string_indexexpr") { AstStatBlock* block = parse(R"( local R = game["Test"] + require(R) )"); - REQUIRE_EQ(1, block->body.size); + REQUIRE_EQ(2, block->body.size); RequireTraceResult result = traceRequires(&fileResolver, block, "ModuleName"); AstStatLocal* local = block->body.data[0]->as(); REQUIRE(local != nullptr); - CHECK_EQ("game/Test", result.exprs[local->values.data[0]]); + CHECK_EQ("game/Test", result.exprs[local->values.data[0]].name); } TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index d7d68c461..e18bf7cdd 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/ToString.h" #include "Fixture.h" @@ -416,8 +417,6 @@ function foo(a, b) return a(b) end TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_TypePack") { - ScopedFastFlag sff{"LuauToStringFollowsBoundTo", true}; - TypeVar tv1{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tv1); ttv->state = TableState::Sealed; diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp new file mode 100644 index 000000000..045f0230d --- /dev/null +++ b/tests/TypeInfer.aliases.test.cpp @@ -0,0 +1,557 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeAliases"); + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type F = () -> F? + local function f() + return f + end + + local g: F = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") +{ + CheckResult result = check(R"( + --!strict + type Node = { Parent: Node?; } + local node: Node; + node.Parent = 1 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Node?", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: a, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = "lo", i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") +{ + CheckResult result = check(R"( + --!strict + type T = { f: a, g: U } + type U = { h: b, i: T? } + local x: T = { f = 37, g = { h = 5, i = nil } } + x.g.i = x + local y: T = { f = "hi", g = { h = 5, i = nil } } + y.g.i = y + )"); + + LUAU_REQUIRE_ERRORS(result); + + // We had a UAF in this example caused by not cloning type function arguments + ModulePtr module = frontend.moduleResolver.getModule("MainModule"); + unfreeze(module->interfaceTypes); + copyErrors(module->errors, module->interfaceTypes); + freeze(module->interfaceTypes); + module->internalTypes.clear(); + module->astTypes.clear(); + + // Make sure the error strings don't include "VALUELESS" + for (auto error : module->errors) + CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); +} + +TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") +{ + CheckResult result = check(R"( + type Pair = {first: T, second: U} + local a: Pair + local b: Pair + + a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + + CHECK_EQ("Pair", toString(tm->wantedType)); + CHECK_EQ("Pair", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") +{ + CheckResult result = check(R"( + type A = number + type A = string -- Redefinition of type 'A', previously defined at line 1 + local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = Table + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("Wrapped", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") +{ + CheckResult result = check(R"( + type Table = { a: T } + type Wrapped = (Table) -> string + local l: Wrapped = 2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); + CHECK_EQ(typeChecker.numberType, tm->givenType); +} + +// Check that recursive intersection type doesn't generate an OOM +TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") +{ + CheckResult result = check(R"( + function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any + end + type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) + _(_) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") +{ + CheckResult result = check(R"( + local foo: Id = 1 + type Id = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") +{ + const std::string code = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb = aa + )"; + + const std::string expected = R"( + type A = {v:T, b:B} + type B = {v:T, a:A} + local aa:A + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") +{ + ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; + + CheckResult result = check(R"( + type A = () -> (number, B) + type B = () -> (string, A) + local a: A + local b: B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); + CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "generic_param_remap") +{ + const std::string code = R"( + -- An example of a forwarded use of a type that has different type arguments than parameters + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb = aa + )"; + + const std::string expected = R"( + + type A = {t:T, u:U, next:A?} + local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } + local bb:A=aa + )"; + + CHECK_EQ(expected, decorateWithTypes(code)); + CheckResult result = check(code); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + export type Foo = number + type Foo = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "Foo"); +} + +TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") +{ + ScopedFastFlag sffs3{"LuauGenericFunctions", true}; + ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; + + CheckResult result = check(R"( + type Node = { value: T, child: Node? } + + local function visitor(node: Node?) + local a: Node + + if node then + a = node.child -- Observe the output of the error message. + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto e = get(result.errors[0]); + CHECK_EQ("Node?", toString(e->givenType)); + CHECK_EQ("Node", toString(e->wantedType)); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") +{ + fileResolver.source["workspace/A"] = R"( + export type myvec2 = {x: number, y: number} + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + export type myvec3 = {x: number, y: number, z: number} + return {} + )"; + + fileResolver.source["workspace/C"] = R"( + local Foo, Bar = require(workspace.A), require(workspace.B) + + local a: Foo.myvec2 + local b: Bar.myvec3 + )"; + + CheckResult result = frontend.check("workspace/C"); + LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; + + REQUIRE(m != nullptr); + + std::optional aTypeId = lookupName(m->getModuleScope(), "a"); + REQUIRE(aTypeId); + const Luau::TableTypeVar* aType = get(follow(*aTypeId)); + REQUIRE(aType); + REQUIRE(aType->props.size() == 2); + + std::optional bTypeId = lookupName(m->getModuleScope(), "b"); + REQUIRE(bTypeId); + const Luau::TableTypeVar* bType = get(follow(*bTypeId)); + REQUIRE(bType); + REQUIRE(bType->props.size() == 3); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") +{ + CheckResult result = check("type t10 = typeof(table)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); + CHECK_EQ(toString(ty), "table"); + + const TableTypeVar* ttv = get(ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +local c: Cool = { a = 1, b = "s" } +type NotCool = Cool +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") +{ + CheckResult result = check(R"( +type Cool = { a: number, b: string } +type NotCool = Cool +local c: Cool = { a = 1, b = "s" } +local d: NotCool = { a = 1, b = "s" } +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "Cool"); + + ty = requireType("d"); + REQUIRE(ty); + CHECK_EQ(toString(*ty), "NotCool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") +{ + CheckResult result = check(R"( +local c = { a = 1, b = "s" } +type Cool = typeof(c) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("c"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK_EQ(ttv->name, "Cool"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: number, b: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(follow(*ty1), follow(*ty2)); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") +{ + fileResolver.source["game/A"] = R"( +export type X = { a: T, b: U, C: X? } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + std::optional ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + std::optional ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); + + bResult = check(R"( +local Import = require(game.A) +type X = Import.X + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ty1 = lookupImportedType("Import", "X"); + REQUIRE(ty1); + + ty2 = lookupType("X"); + REQUIRE(ty2); + + CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); + CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return function(obj) return true end +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") +{ + CheckResult result = check(R"( +function get() + return {a = 1, b = function(obj) return true end} +end + +export type f = typeof(get()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: Forest } + type Forest = {Tree} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- OK because forwarded types are used with their parameters. + type Tree = { data: T, children: Forest } + type Forest = {Tree<{T}>} + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- Not OK because forwarded types are used with different types than their parameters. + type Forest = {Tree<{T}>} + type Tree = { data: T, children: Forest } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") +{ + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + type Tree1 = { data: T, children: {Tree2} } + type Tree2 = { data: U, children: {Tree1} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") +{ + CheckResult result = check(R"( + function f(x) return x[1] end + -- x has type X? for a free type variable X + local x = f ({}) + type ContainsFree = { this: a, that: typeof(x) } + type ContainsContainsFree = { that: ContainsFree } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") +{ + ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true}; + + CheckResult result = check(R"( + type Array = { [number]: T } + type Tuple = Array + + local p: Tuple + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{number | string}", toString(requireType("p"), {true})); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 46496fdb1..8bcb02424 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -30,6 +30,8 @@ TEST_SUITE_BEGIN("ProvisionalTests"); */ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") { + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + const std::string code = R"( function f(a) if type(a) == "boolean" then @@ -41,11 +43,11 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") )"; const std::string expected = R"( - function f(a:{fn:()->(free)}): () + function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a elseif a.fn()then - local a2:{fn:()->(free)}=a + local a2:{fn:()->(free,free...)}=a end end )"; @@ -231,16 +233,7 @@ TEST_CASE_FIXTURE(Fixture, "operator_eq_completely_incompatible") local r2 = b == a )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Type '{| x: string |}?' could not be converted into 'number | string'"); - CHECK_EQ(toString(result.errors[1]), "Type 'number | string' could not be converted into '{| x: string |}?'"); - } + LUAU_REQUIRE_NO_ERRORS(result); } // Belongs in TypeInfer.refinements.test.cpp. @@ -542,6 +535,25 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct } } +TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") +{ + CheckResult result = check(R"( + --!strict + local function setNumber(t: { p: number? }, x:number) t.p = x end + local function getString(t: { p: string? }):string return t.p or "" end + -- This shouldn't type-check! + local function oh(x:number): string + local t: {} = {} + setNumber(t, x) + return getString(t) + end + local s: string = oh(37) + )"); + + // Really this should return an error, but it doesn't + LUAU_REQUIRE_NO_ERRORS(result); +} + // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f2ba0ddc5..31739cdc7 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1,4 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Fixture.h" @@ -6,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauImprovedTypeGuardPredicate2) LUAU_FASTFLAG(LuauOrPredicate) using namespace Luau; @@ -199,16 +199,8 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauImprovedTypeGuardPredicate2) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'string' could not be converted into 'boolean'", toString(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -526,8 +518,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_property_of_a_bounded_variable") TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x) if type(x) == "vector" then @@ -544,8 +534,6 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local t = {"hello"} local v = t[2] @@ -573,8 +561,6 @@ TEST_CASE_FIXTURE(Fixture, "nonoptional_type_can_narrow_to_nil_if_sense_is_true" TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | boolean) if type(x) ~= "string" then @@ -593,8 +579,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_not_to_be_string") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | {x: number} | {y: boolean}) if type(x) == "table" then @@ -613,8 +597,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_table") TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function weird(x: string | ((number) -> string)) if type(x) == "function" then @@ -698,8 +680,6 @@ struct RefinementClassFixture : Fixture TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(vec) local X, Y, Z = vec.X, vec.Y, vec.Z @@ -726,8 +706,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: Instance | Vector3) if typeof(x) == "Vector3" then @@ -746,8 +724,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(x: string | number | Instance | Vector3) if type(x) == "userdata" then @@ -766,10 +742,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | string) @@ -789,10 +762,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") { - ScopedFastFlag sffs[] = { - {"LuauImprovedTypeGuardPredicate2", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; + ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; CheckResult result = check(R"( local function f(x: Part | Folder | Instance | string | Vector3 | any) @@ -812,10 +782,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( --!nonstrict @@ -839,7 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { ScopedFastFlag sffs[] = { {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, {"LuauTypeGuardPeelsAwaySubclasses", true}, }; @@ -861,8 +827,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type XYCoord = {x: number} & {y: number} local function f(t: XYCoord?) @@ -882,8 +846,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( type SomeOverloadedFunction = ((number) -> string) & ((string) -> number) local function f(g: SomeOverloadedFunction?) @@ -903,8 +865,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_overloaded_function") TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_sense_is_true") { - ScopedFastFlag sff2{"LuauImprovedTypeGuardPredicate2", true}; - CheckResult result = check(R"( local function f(t: {x: number}) if type(t) ~= "table" then @@ -999,10 +959,7 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "either_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(x: any) @@ -1036,10 +993,7 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local a: (number | string)? @@ -1057,10 +1011,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( @@ -1081,10 +1032,7 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauImprovedTypeGuardPredicate2", true}, - }; + ScopedFastFlag sff{"LuauOrPredicate", true}; CheckResult result = check(R"( local function f(a: string | number | boolean) diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 1d1b2fae2..b7f0dc7b0 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -46,6 +46,21 @@ TEST_CASE_FIXTURE(Fixture, "augment_table") CHECK(tType->props.find("foo") != tType->props.end()); } +TEST_CASE_FIXTURE(Fixture, "augment_nested_table") +{ + CheckResult result = check("local t = { p = {} } t.p.foo = 'bar'"); + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* tType = getMutable(requireType("t")); + REQUIRE(tType != nullptr); + + REQUIRE(tType->props.find("p") != tType->props.end()); + const TableTypeVar* pType = get(tType->props["p"].type); + REQUIRE(pType != nullptr); + + CHECK(pType->props.find("foo") != pType->props.end()); +} + TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); @@ -260,6 +275,8 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local a = {} a.x = 99 @@ -272,10 +289,11 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ(error->key, "y"); + CHECK_EQ("y", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time // CHECK_EQ(err.location, Location(Position{5, 19}, Position{5, 25})); @@ -328,6 +346,8 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( --!strict function foo(o) @@ -340,14 +360,17 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") LUAU_REQUIRE_ERROR_COUNT(1, result); - UnknownProperty* error = get(result.errors[0]); + MissingProperties* error = get(result.errors[0]); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); - CHECK_EQ("baz", error->key); + CHECK_EQ("baz", error->properties[0]); } TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local T = {} T.bar = 'hello' @@ -359,8 +382,11 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; - UnknownProperty* error = get(err); + MissingProperties* error = get(err); REQUIRE(error != nullptr); + REQUIRE(error->properties.size() == 1); + + CHECK_EQ("baz", error->properties[0]); // TODO(rblanckaert): Revist when we can bind self at function creation time /* @@ -448,6 +474,73 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") dumpErrors(result); } +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local t = { u = {} } + t = { u = { p = 37 } } + t = { u = { q = "hi" } } + local x = t.u.p + local y = t.u.q + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number?", toString(requireType("x"))); + CHECK_EQ("string?", toString(requireType("y"))); +} + +TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_call") +{ + CheckResult result = check(R"( + --!strict + function get(x) return x.opts["MYOPT"] end + function set(x,y) x.opts["MYOPT"] = y end + local t = { opts = {} } + set(t,37) + local x = get(t) + )"); + + // Currently this errors but it shouldn't, since set only needs write access + // TODO: file a JIRA for this + LUAU_REQUIRE_ERRORS(result); + // CHECK_EQ("number?", toString(requireType("x"))); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + function f(x : { q : number }) + x.q = 8 + end + local t : { q : number, r : string } = { q = 8, r = "hi" } + f(t) + local x : string = t.r + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "width_subtyping_needs_covariance") +{ + CheckResult result = check(R"( + --!strict + function f(x : { p : { q : number }}) + x.p = { q = 8, r = 5 } + end + local t : { p : { q : number, r : string } } = { p = { q = 8, r = "hi" } } + f(t) -- Shouldn't typecheck + local x : string = t.p.r -- x is 5 + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "infer_array") { CheckResult result = check(R"( @@ -676,16 +769,27 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "sealed_table_value_must_not_infer_an_indexer") +TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); +} - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm != nullptr); +TEST_CASE_FIXTURE(Fixture, "array_factory_function") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + function empty() return {} end + local array: {string} = empty() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "sealed_table_indexers_must_unify") @@ -756,37 +860,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_should_prefer_properties_when_ CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); } -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_string") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t: { a: string } - function f(x: string) return t[x] end - local a = f("a") - local b = f("b") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.anyType, *requireType("a")); - CHECK_EQ(*typeChecker.anyType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_from_a_table_with_a_number") -{ - ScopedFastFlag fflag("LuauIndexTablesWithIndexers", true); - - CheckResult result = check(R"( - local t = { a = true } - function f(x: number) return t[x] end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_MESSAGE(nullptr != get(result.errors[0]), "Expected a TypeMismatch but got " << result.errors[0]); -} - TEST_CASE_FIXTURE(Fixture, "assigning_to_an_unsealed_table_with_string_literal_should_infer_new_properties_over_indexer") { CheckResult result = check(R"( @@ -1392,6 +1465,8 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end foo({ a = 1 }) @@ -1402,8 +1477,21 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); - CHECK_EQ("string", toString(tm->wantedType, o)); - CHECK_EQ("number", toString(tm->givenType, o)); + CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); + CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") +{ + CheckResult result = check(R"( + local function foo(a: {[string]: number, a: string}, i: string) + return a[i] + end + local hi: number = foo({ a = "hi" }, "a") -- shouldn't typecheck since at runtime hi is "hi" + )"); + + // This typechecks but shouldn't + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multiple_errors") @@ -1446,22 +1534,32 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {x = 1, y = 2, z = 3} - local vec1 = {x = 1} + local vec3 = {{x = 1, y = 2, z = 3}} + local vec1 = {{x = 1}} vec1 = vec3 )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - MissingProperties* mp = get(result.errors[0]); - REQUIRE(mp); - CHECK_EQ(mp->context, MissingProperties::Extra); - REQUIRE_EQ(2, mp->properties.size()); - CHECK_EQ(mp->properties[0], "y"); - CHECK_EQ(mp->properties[1], "z"); - CHECK_EQ("vec1", toString(mp->superType)); - CHECK_EQ("vec3", toString(mp->subType)); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("vec1", toString(tm->wantedType)); + CHECK_EQ("vec3", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + local vec3 = {x = 1, y = 2, z = 3} + local vec1 = {x = 1} + + vec1 = vec3 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_mismatch_on_massive_table_is_cut_short") @@ -1824,4 +1922,32 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") +{ + ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + + CheckResult result = check(R"( + --!strict + local buttons = {} + table.insert(buttons, { a = 1 }) + table.insert(buttons, { a = 2, b = true }) + table.insert(buttons, { a = 3 }) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 37333b193..b75878b7e 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -3,6 +3,7 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" @@ -978,23 +979,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type F = () -> F? - local function f() - return f - end - - local g: F = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1?", toString(requireType("g"))); -} - // TODO: File a Jira about this /* TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") @@ -1257,23 +1241,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") REQUIRE_EQ(follow(*methodArg), follow(arg)); } -TEST_CASE_FIXTURE(Fixture, "cyclic_types_of_named_table_fields_do_not_expand_when_stringified") -{ - CheckResult result = check(R"( - --!strict - type Node = { Parent: Node?; } - local node: Node; - node.Parent = 1 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Node?", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") { CheckResult result = check(R"( @@ -2591,48 +2558,6 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: a, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = "lo", i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_errors") -{ - CheckResult result = check(R"( - --!strict - type T = { f: a, g: U } - type U = { h: b, i: T? } - local x: T = { f = 37, g = { h = 5, i = nil } } - x.g.i = x - local y: T = { f = "hi", g = { h = 5, i = nil } } - y.g.i = y - )"); - - LUAU_REQUIRE_ERRORS(result); - - // We had a UAF in this example caused by not cloning type function arguments - ModulePtr module = frontend.moduleResolver.getModule("MainModule"); - unfreeze(module->interfaceTypes); - copyErrors(module->errors, module->interfaceTypes); - freeze(module->interfaceTypes); - module->internalTypes.clear(); - module->astTypes.clear(); - - // Make sure the error strings don't include "VALUELESS" - for (auto error : module->errors) - CHECK_MESSAGE(toString(error).find("VALUELESS") == std::string::npos, toString(error)); -} - TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") { // CLI-30902 @@ -3369,16 +3294,7 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") end )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") @@ -3388,18 +3304,8 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable print((x == true and (x .. "y")) .. 1) )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ("Type 'boolean' could not be converted into 'number | string'", toString(result.errors[0])); - CHECK_EQ("Type 'boolean | string' could not be converted into 'number | string'", toString(result.errors[1])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") @@ -3511,25 +3417,6 @@ _(...)(...,setfenv,_):_G() )"); } -TEST_CASE_FIXTURE(Fixture, "use_table_name_and_generic_params_in_errors") -{ - CheckResult result = check(R"( - type Pair = {first: T, second: U} - local a: Pair - local b: Pair - - a = b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - - CHECK_EQ("Pair", toString(tm->wantedType)); - CHECK_EQ("Pair", toString(tm->givenType)); -} - TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") { // this has a risk of creating cyclic type packs, causing infinite loops / OOMs @@ -3639,17 +3526,6 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") )"); } -TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_type_definition") -{ - CheckResult result = check(R"( - type A = number - type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") { CheckResult result = check(R"( @@ -3752,38 +3628,6 @@ TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operato CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); } -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = Table - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("Wrapped", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "stringify_type_alias_of_recursive_template_table_type2") -{ - CheckResult result = check(R"( - type Table = { a: T } - type Wrapped = (Table) -> string - local l: Wrapped = 2 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("t1 where t1 = ({| a: t1 |}) -> string", toString(tm->wantedType)); - CHECK_EQ(typeChecker.numberType, tm->givenType); -} - TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") { CheckResult result = check(R"( @@ -3909,19 +3753,6 @@ TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } -// Check that recursive intersection type doesn't generate an OOM -TEST_CASE_FIXTURE(Fixture, "cli_38393_recursive_intersection_oom") -{ - CheckResult result = check(R"( - function _(l0:(t0)&((t0)&(((t0)&((t0)->()))->(typeof(_),typeof(# _)))),l39,...):any - end - type t0 = ((typeof(_))&((t0)&(((typeof(_))&(t0))->typeof(_))),{n163:any,})->(any,typeof(_)) - _(_) - )"); - - LUAU_REQUIRE_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") { // In non-strict mode, global definition is still allowed @@ -3974,16 +3805,6 @@ TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") LUAU_REQUIRE_ERROR_COUNT(2, result); } -TEST_CASE_FIXTURE(Fixture, "type_alias_fwd_declaration_is_precise") -{ - CheckResult result = check(R"( - local foo: Id = 1 - type Id = T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") { CheckResult result = check(R"( @@ -4014,81 +3835,6 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") -{ - const std::string code = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb = aa - )"; - - const std::string expected = R"( - type A = {v:T, b:B} - type B = {v:T, a:A} - local aa:A - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") -{ - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - - CheckResult result = check(R"( - type A = () -> (number, B) - type B = () -> (string, A) - local a: A - local b: B - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("t1 where t1 = () -> (number, () -> (string, t1))", toString(requireType("a"))); - CHECK_EQ("t1 where t1 = () -> (string, () -> (number, t1))", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "generic_param_remap") -{ - const std::string code = R"( - -- An example of a forwarded use of a type that has different type arguments than parameters - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb = aa - )"; - - const std::string expected = R"( - - type A = {t:T, u:U, next:A?} - local aa:A = { t = 5, u = 'hi', next = { t = 'lo', u = 8 } } - local bb:A=aa - )"; - - CHECK_EQ(expected, decorateWithTypes(code)); - CheckResult result = check(code); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") -{ - CheckResult result = check(R"( - export type Foo = number - type Foo = number - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto dtd = get(result.errors[0]); - REQUIRE(dtd); - CHECK_EQ(dtd->name, "Foo"); -} - TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { CheckResult result = check(R"( @@ -4193,30 +3939,6 @@ TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); } -TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") -{ - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - - CheckResult result = check(R"( - type Node = { value: T, child: Node? } - - local function visitor(node: Node?) - local a: Node - - if node then - a = node.child -- Observe the output of the error message. - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto e = get(result.errors[0]); - CHECK_EQ("Node?", toString(e->givenType)); - CHECK_EQ("Node", toString(e->wantedType)); -} - TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") { CheckResult result = check(R"( @@ -4272,181 +3994,6 @@ local tbl: string = require(game.A) CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "general_require_multi_assign") -{ - fileResolver.source["workspace/A"] = R"( - export type myvec2 = {x: number, y: number} - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - export type myvec3 = {x: number, y: number, z: number} - return {} - )"; - - fileResolver.source["workspace/C"] = R"( - local Foo, Bar = require(workspace.A), require(workspace.B) - - local a: Foo.myvec2 - local b: Bar.myvec3 - )"; - - CheckResult result = frontend.check("workspace/C"); - LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr m = frontend.moduleResolver.modules["workspace/C"]; - - REQUIRE(m != nullptr); - - std::optional aTypeId = lookupName(m->getModuleScope(), "a"); - REQUIRE(aTypeId); - const Luau::TableTypeVar* aType = get(follow(*aTypeId)); - REQUIRE(aType); - REQUIRE(aType->props.size() == 2); - - std::optional bTypeId = lookupName(m->getModuleScope(), "b"); - REQUIRE(bTypeId); - const Luau::TableTypeVar* bType = get(follow(*bTypeId)); - REQUIRE(bType); - REQUIRE(bType->props.size() == 3); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_import_mutation") -{ - CheckResult result = check("type t10 = typeof(table)"); - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId ty = getGlobalBinding(frontend.typeChecker, "table"); - CHECK_EQ(toString(ty), "table"); - - const TableTypeVar* ttv = get(ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_mutation") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -local c: Cool = { a = 1, b = "s" } -type NotCool = Cool -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - - CHECK(ttv->instantiatedTypeParams.empty()); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_rename") -{ - CheckResult result = check(R"( -type Cool = { a: number, b: string } -type NotCool = Cool -local c: Cool = { a = 1, b = "s" } -local d: NotCool = { a = 1, b = "s" } -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "Cool"); - - ty = requireType("d"); - REQUIRE(ty); - CHECK_EQ(toString(*ty), "NotCool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_local_synthetic_mutation") -{ - CheckResult result = check(R"( -local c = { a = 1, b = "s" } -type Cool = typeof(c) -)"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("c"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK_EQ(ttv->name, "Cool"); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: number, b: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(follow(*ty1), follow(*ty2)); -} - -TEST_CASE_FIXTURE(Fixture, "type_alias_of_an_imported_recursive_generic_type") -{ - ScopedFastFlag luauFixTableTypeAliasClone{"LuauFixTableTypeAliasClone", true}; - - fileResolver.source["game/A"] = R"( -export type X = { a: T, b: U, C: X? } -return {} - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - std::optional ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - std::optional ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), toString(*ty2, {true})); - - bResult = check(R"( -local Import = require(game.A) -type X = Import.X - )"); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ty1 = lookupImportedType("Import", "X"); - REQUIRE(ty1); - - ty2 = lookupType("X"); - REQUIRE(ty2); - - CHECK_EQ(toString(*ty1, {true}), "t1 where t1 = {| C: t1?, a: T, b: U |}"); - CHECK_EQ(toString(*ty2, {true}), "{| C: t1, a: U, b: T |} where t1 = {| C: t1, a: U, b: T |}?"); -} - TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") { CheckResult result = check(R"( @@ -4560,32 +4107,6 @@ local c = a(2) -- too many arguments CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "module_export_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return function(obj) return true end -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "module_export_wrapped_free_type_leak") -{ - CheckResult result = check(R"( -function get() - return {a = 1, b = function(obj) return true end} -end - -export type f = typeof(get()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "custom_require_global") { CheckResult result = check(R"( @@ -4768,8 +4289,6 @@ TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") { - ScopedFastFlag sff{"LuauLogTableTypeVarBoundTo", true}; - fileResolver.source["Module/Backend/Types"] = R"( export type Fiber = { return_: Fiber? @@ -4849,8 +4368,8 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") ModulePtr module = getMainModule(); auto it = module->astOverloadResolvedTypes.find(parentExpr); - REQUIRE(it != module->astOverloadResolvedTypes.end()); - CHECK_EQ(toString(it->second), "(number) -> number"); + REQUIRE(it); + CHECK_EQ(toString(*it), "(number) -> number"); } TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") @@ -5013,76 +4532,6 @@ g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: Forest } - type Forest = {Tree} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_1") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- OK because forwarded types are used with their parameters. - type Tree = { data: T, children: Forest } - type Forest = {Tree<{T}>} - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_restriction_not_ok_2") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- Not OK because forwarded types are used with different types than their parameters. - type Forest = {Tree<{T}>} - type Tree = { data: T, children: Forest } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_ok") -{ - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "mutually_recursive_types_swapsies_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - type Tree1 = { data: T, children: {Tree2} } - type Tree2 = { data: U, children: {Tree1} } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") -{ - CheckResult result = check(R"( - function f(x) return x[1] end - -- x has type X? for a free type variable X - local x = f ({}) - type ContainsFree = { this: a, that: typeof(x) } - type ContainsContainsFree = { that: ContainsFree } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 91ac9f062..1f4b63ef2 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5f7f28474..3e1dedd47 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -294,4 +294,370 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T...) -> T... +local a: Packed<> +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "(T...) -> (T...)"); + CHECK_EQ(toString(requireType("a")), "() -> ()"); + CHECK_EQ(toString(requireType("b")), "(number) -> number"); + CHECK_EQ(toString(requireType("c")), "(string, number) -> (string, number)"); + + result = check(R"( +-- (U..., T) cannot be parsed right now +type Packed = { f: (a: T, U...) -> (T, U...) } +local a: Packed +local b: Packed +local c: Packed + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + tf = lookupType("Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| f: (T, U...) -> (T, U...) |}"); + + auto ttvA = get(requireType("a")); + REQUIRE(ttvA); + CHECK_EQ(toString(requireType("a")), "Packed"); + CHECK_EQ(toString(requireType("a"), {true}), "{| f: (number) -> (number) |}"); + REQUIRE(ttvA->instantiatedTypeParams.size() == 1); + REQUIRE(ttvA->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvA->instantiatedTypeParams[0], {true}), "number"); + CHECK_EQ(toString(ttvA->instantiatedTypePackParams[0], {true}), ""); + + auto ttvB = get(requireType("b")); + REQUIRE(ttvB); + CHECK_EQ(toString(requireType("b")), "Packed"); + CHECK_EQ(toString(requireType("b"), {true}), "{| f: (string, number) -> (string, number) |}"); + REQUIRE(ttvB->instantiatedTypeParams.size() == 1); + REQUIRE(ttvB->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvB->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvB->instantiatedTypePackParams[0], {true}), "number"); + + auto ttvC = get(requireType("c")); + REQUIRE(ttvC); + CHECK_EQ(toString(requireType("c")), "Packed"); + CHECK_EQ(toString(requireType("c"), {true}), "{| f: (string, number, boolean) -> (string, number, boolean) |}"); + REQUIRE(ttvC->instantiatedTypeParams.size() == 1); + REQUIRE(ttvC->instantiatedTypePackParams.size() == 1); + CHECK_EQ(toString(ttvC->instantiatedTypeParams[0], {true}), "string"); + CHECK_EQ(toString(ttvC->instantiatedTypePackParams[0], {true}), "number, boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = check(R"( +local Import = require(game.A) +local a: Import.Packed +local b: Import.Packed +local c: Import.Packed +local d: { a: typeof(c) } + )"); + LUAU_REQUIRE_NO_ERRORS(bResult); + + auto tf = lookupImportedType("Import", "Packed"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Packed"); + CHECK_EQ(toString(*tf, {true}), "{| a: T, b: (U...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: number, b: () -> () |}"); + CHECK_EQ(toString(requireType("b"), {true}), "{| a: string, b: (number) -> () |}"); + CHECK_EQ(toString(requireType("c"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + CHECK_EQ(toString(requireType("d")), "{| a: Packed |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + fileResolver.source["game/A"] = R"( +export type Packed = { a: T, b: (U...) -> () } +return {} + )"; + + CheckResult cResult = check(R"( +local Import = require(game.A) +type Alias = Import.Packed +local a: Alias + +type B = Import.Packed +type C = Import.Packed + )"); + LUAU_REQUIRE_NO_ERRORS(cResult); + + auto tf = lookupType("Alias"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "Alias"); + CHECK_EQ(toString(*tf, {true}), "{| a: S, b: (T, R...) -> () |}"); + + CHECK_EQ(toString(requireType("a"), {true}), "{| a: string, b: (number, boolean) -> () |}"); + + tf = lookupType("B"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "B"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (X...) -> () |}"); + + tf = lookupType("C"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), "C"); + CHECK_EQ(toString(*tf, {true}), "{| a: string, b: (number, X...) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed1 = (T...) -> (T...) +type Packed2 = (Packed1, T...) -> (Packed1, T...) +type Packed3 = (Packed2, T...) -> (Packed2, T...) +type Packed4 = (Packed3, T...) -> (Packed3, T...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto tf = lookupType("Packed4"); + REQUIRE(tf); + CHECK_EQ(toString(*tf), + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...) -> " + "((((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...) -> (((T...) -> (T...), T...) -> ((T...) -> (T...), T...), T...), T...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (string, T...) + +type D = X<...number> +type E = X<(number, ...string)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("D")), "(...number) -> (string, ...number)"); + CHECK_EQ(toString(*lookupType("E")), "(number, ...string) -> (string, number, ...string)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) +type A = Y +type B = Y<(number, ...string), S...> + +type Z = (T) -> (U...) +type E = Z +type F = Z + +type W = (T, U...) -> (T, V...) +type H = W +type I = W + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "(number, ...string) -> (S...)"); + + CHECK_EQ(toString(*lookupType("E")), "(number) -> (S...)"); + CHECK_EQ(toString(*lookupType("F")), "(number) -> (string, S...)"); + + CHECK_EQ(toString(*lookupType("H")), "(number, S...) -> (number, R...)"); + CHECK_EQ(toString(*lookupType("I")), "(number, string, S...) -> (number, R...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = (T...) -> (T...) + +type A = X<(S...)> +type B = X<()> +type C = X<(number)> +type D = X<(number, string)> +type E = X<(...number)> +type F = X<(string, ...number)> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(S...) -> (S...)"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> number"); + CHECK_EQ(toString(*lookupType("D")), "(number, string) -> (number, string)"); + CHECK_EQ(toString(*lookupType("E")), "(...number) -> (...number)"); + CHECK_EQ(toString(*lookupType("F")), "(string, ...number) -> (string, ...number)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Y = (T...) -> (U...) + +type A = Y<(number, string), (boolean)> +type B = Y<(), ()> +type C = Y<...string, (number, S...)> +type D = Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "(number, string) -> boolean"); + CHECK_EQ(toString(*lookupType("B")), "() -> ()"); + CHECK_EQ(toString(*lookupType("C")), "(...string) -> (number, S...)"); + CHECK_EQ(toString(*lookupType("D")), "(X...) -> (number, string, X...)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block + + CheckResult result = check(R"( +type Y = { f: (T...) -> (U...) } + +local a: Y<(number, string), (boolean)> +local b: Y<(), ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (boolean)>"); + CHECK_EQ(toString(requireType("b")), "Y<(), ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type X = () -> T +type Y = (T) -> U + +type A = X<(number)> +type B = Y<(number), (boolean)> +type C = Y<(number), boolean> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("A")), "() -> number"); + CHECK_EQ(toString(*lookupType("B")), "(number) -> boolean"); + CHECK_EQ(toString(*lookupType("C")), "(number) -> boolean"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") +{ + ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + CheckResult result = check(R"( +type Packed = (T, U) -> (V...) +local b: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects at least 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T, U) -> () +type B = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 0 type pack arguments, but 1 is specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameters must come before type pack parameters"); + + result = check(R"( +type Packed = (T) -> U +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type arguments, but only 1 is specified"); + + result = check(R"( +type Packed = (T...) -> T... +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but none are specified"); + + result = check(R"( +type Packed = (T...) -> (U...) +type Other = Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ae4d836b4..037144e2a 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -237,21 +237,7 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") local z = a == c )"); - if (FFlag::LuauEqConstraint) - { - LUAU_REQUIRE_NO_ERRORS(result); - } - else - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*typeChecker.booleanType, *requireType("x")); - CHECK_EQ(*typeChecker.booleanType, *requireType("y")); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(number | string)?", toString(*tm->wantedType)); - CHECK_EQ("boolean | number", toString(*tm->givenType)); - } + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "optional_union_members") diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 98ce9f939..a679e3fd2 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,5 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" +#include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tools/tracegraph.py b/tools/tracegraph.py new file mode 100644 index 000000000..a46423e7e --- /dev/null +++ b/tools/tracegraph.py @@ -0,0 +1,95 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# Given a trace event file, this tool generates a flame graph based on the event scopes present in the file +# The result of analysis is a .svg file which can be viewed in a browser + +import sys +import svg +import json + +class Node(svg.Node): + def __init__(self): + svg.Node.__init__(self) + self.caption = "" + self.description = "" + self.ticks = 0 + + def text(self): + return self.caption + + def title(self): + return self.caption + + def details(self, root): + return "{} ({:,} usec, {:.1%}); self: {:,} usec".format(self.description, self.width, self.width / root.width, self.ticks) + +with open(sys.argv[1]) as f: + dump = f.read() + +root = Node() + +# Finish the file +if not dump.endswith("]"): + dump += "{}]" + +data = json.loads(dump) + +stacks = {} + +for l in data: + if len(l) == 0: + continue + + # Track stack of each thread, but aggregate values together + tid = l["tid"] + + if not tid in stacks: + stacks[tid] = [] + stack = stacks[tid] + + if l["ph"] == 'B': + stack.append(l) + elif l["ph"] == 'E': + node = root + + for e in stack: + caption = e["name"] + description = '' + + if "args" in e: + for arg in e["args"]: + if len(description) != 0: + description += ", " + + description += "{}: {}".format(arg, e["args"][arg]) + + child = node.child(caption + description) + child.caption = caption + child.description = description + + node = child + + begin = stack[-1] + + ticks = l["ts"] - begin["ts"] + rawticks = ticks + + # Flame graph requires ticks without children duration + if "childts" in begin: + ticks -= begin["childts"] + + node.ticks += int(ticks) + + stack.pop() + + if len(stack): + parent = stack[-1] + + if "childts" in parent: + parent["childts"] += rawticks + else: + parent["childts"] = rawticks + +svg.layout(root, lambda n: n.ticks) +svg.display(root, "Flame Graph", "hot", flip = True) From 34cf695fbc35eb435dcd9fb85c3b98234fdd266c Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 4 Nov 2021 19:42:00 -0700 Subject: [PATCH 02/32] Sync to upstream/release/503 - A series of major optimizations to type checking performance on complex programs/types (up to two orders of magnitude speedup for programs involving huge tagged unions) - Fix a few issues encountered by UBSAN (and maybe fix s390x builds) - Fix gcc-11 test builds - Fix a rare corner case where luau_load wouldn't wake inactive threads which could result in a use-after-free due to GC - Fix CLI crash when error object that's not a string escapes to top level --- Analysis/include/Luau/BuiltinDefinitions.h | 1 - Analysis/include/Luau/Quantify.h | 14 + Analysis/include/Luau/ToString.h | 2 + Analysis/include/Luau/TopoSortStatements.h | 1 + Analysis/include/Luau/TxnLog.h | 28 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeVar.h | 9 +- Analysis/include/Luau/Unifier.h | 20 +- Analysis/include/Luau/UnifierSharedState.h | 44 ++ Analysis/include/Luau/VisitTypeVar.h | 59 +- Analysis/src/Autocomplete.cpp | 5 +- Analysis/src/BuiltinDefinitions.cpp | 12 - Analysis/src/Error.cpp | 21 +- Analysis/src/Frontend.cpp | 11 +- Analysis/src/Module.cpp | 18 +- Analysis/src/Quantify.cpp | 90 +++ Analysis/src/RequireTracer.cpp | 6 +- Analysis/src/ToString.cpp | 25 +- Analysis/src/TopoSortStatements.cpp | 25 + Analysis/src/TxnLog.cpp | 37 +- Analysis/src/TypeAttach.cpp | 93 ++- Analysis/src/TypeInfer.cpp | 199 ++--- Analysis/src/TypeVar.cpp | 86 ++- Analysis/src/Unifier.cpp | 349 ++++++++- Ast/include/Luau/TimeTrace.h | 16 +- Ast/src/Parser.cpp | 2 +- Ast/src/TimeTrace.cpp | 5 +- CLI/Repl.cpp | 24 +- Compiler/include/Luau/Bytecode.h | 4 +- Compiler/src/Compiler.cpp | 14 +- Makefile | 2 + Sources.cmake | 3 + VM/include/lualib.h | 2 +- VM/src/lapi.cpp | 4 +- VM/src/laux.cpp | 2 +- VM/src/lcorolib.cpp | 72 +- VM/src/ldo.cpp | 14 +- VM/src/ldo.h | 2 +- VM/src/lfunc.cpp | 2 +- VM/src/lgc.cpp | 331 ++++---- VM/src/lgc.h | 8 +- VM/src/lmem.cpp | 2 +- VM/src/lstring.cpp | 2 +- VM/src/lvmload.cpp | 4 +- bench/tests/chess.lua | 849 +++++++++++++++++++++ bench/tests/shootout/scimark.lua | 2 +- tests/Autocomplete.test.cpp | 24 +- tests/Compiler.test.cpp | 84 +- tests/Conformance.test.cpp | 13 + tests/IostreamOptional.h | 7 +- tests/Linter.test.cpp | 14 +- tests/TypeInfer.aliases.test.cpp | 50 ++ tests/TypeInfer.builtins.test.cpp | 2 +- tests/TypeInfer.classes.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 21 + tests/TypeInfer.provisional.test.cpp | 25 +- tests/TypeInfer.refinements.test.cpp | 11 +- tests/TypeInfer.tables.test.cpp | 2 +- tests/TypeInfer.test.cpp | 62 +- tests/TypeInfer.tryUnify.test.cpp | 10 +- tests/TypeInfer.typePacks.cpp | 6 +- tests/TypeInfer.unionTypes.test.cpp | 25 +- tests/TypeVar.test.cpp | 60 ++ tests/conformance/closure.lua | 2 +- tests/conformance/coroutine.lua | 2 +- tests/conformance/gc.lua | 2 +- tests/conformance/locals.lua | 2 +- tests/conformance/math.lua | 2 +- tests/conformance/pm.lua | 4 +- tests/conformance/tmerror.lua | 15 + tools/gdb-printers.py | 8 +- 71 files changed, 2300 insertions(+), 683 deletions(-) create mode 100644 Analysis/include/Luau/Quantify.h create mode 100644 Analysis/include/Luau/UnifierSharedState.h create mode 100644 Analysis/src/Quantify.cpp create mode 100644 bench/tests/chess.lua create mode 100644 tests/conformance/tmerror.lua diff --git a/Analysis/include/Luau/BuiltinDefinitions.h b/Analysis/include/Luau/BuiltinDefinitions.h index 57a1907a5..07d897b2f 100644 --- a/Analysis/include/Luau/BuiltinDefinitions.h +++ b/Analysis/include/Luau/BuiltinDefinitions.h @@ -34,7 +34,6 @@ TypeId makeFunction( // Polymorphic std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); void attachMagicFunction(TypeId ty, MagicFunction fn); -void attachFunctionTag(TypeId ty, std::string constraint); Property makeProperty(TypeId ty, std::optional documentationSymbol = std::nullopt); void assignPropDocumentationSymbols(TableTypeVar::Props& props, const std::string& baseName); diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h new file mode 100644 index 000000000..f46df1460 --- /dev/null +++ b/Analysis/include/Luau/Quantify.h @@ -0,0 +1,14 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/TypeVar.h" + +namespace Luau +{ + +struct Module; +using ModulePtr = std::shared_ptr; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level); + +} // namespace Luau diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 0897ec854..e5683fc40 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,4 +69,6 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); void dump(TypeId ty); void dump(TypePackId ty); +std::string generateName(size_t n); + } // namespace Luau diff --git a/Analysis/include/Luau/TopoSortStatements.h b/Analysis/include/Luau/TopoSortStatements.h index 751694f02..4a4acfa3a 100644 --- a/Analysis/include/Luau/TopoSortStatements.h +++ b/Analysis/include/Luau/TopoSortStatements.h @@ -12,6 +12,7 @@ struct AstArray; class AstStat; bool containsFunctionCall(const AstStat& stat); +bool containsFunctionCallOrReturn(const AstStat& stat); bool isFunction(const AstStat& stat); void toposort(std::vector& stats); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 055441ce8..322abd198 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -3,17 +3,35 @@ #include "Luau/TypeVar.h" +LUAU_FASTFLAG(LuauShareTxnSeen); + namespace Luau { // Log of where what TypeIds we are rebinding and what they used to be struct TxnLog { - TxnLog() = default; + TxnLog() + : originalSeenSize(0) + , ownedSeen() + , sharedSeen(&ownedSeen) + { + } + + explicit TxnLog(std::vector>* sharedSeen) + : originalSeenSize(sharedSeen->size()) + , ownedSeen() + , sharedSeen(sharedSeen) + { + } - explicit TxnLog(const std::vector>& seen) - : seen(seen) + explicit TxnLog(const std::vector>& ownedSeen) + : originalSeenSize(ownedSeen.size()) + , ownedSeen(ownedSeen) + , sharedSeen(nullptr) { + // This is deprecated! + LUAU_ASSERT(!FFlag::LuauShareTxnSeen); } TxnLog(const TxnLog&) = delete; @@ -38,9 +56,11 @@ struct TxnLog std::vector> typeVarChanges; std::vector> typePackChanges; std::vector>> tableChanges; + size_t originalSeenSize; public: - std::vector> seen; // used to avoid infinite recursion when types are cyclic + std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic + std::vector>* sharedSeen; // shared with all the descendent logs }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index d701eb248..9d62fef0b 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -11,6 +11,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/Unifier.h" +#include "Luau/UnifierSharedState.h" #include #include @@ -121,7 +122,7 @@ struct TypeChecker void check(const ScopePtr& scope, const AstStatForIn& forin); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatFunction& function); void check(const ScopePtr& scope, TypeId ty, const ScopePtr& funScope, const AstStatLocalFunction& function); - void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare = false); + void check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel = 0, bool forwardDeclare = false); void check(const ScopePtr& scope, const AstStatDeclareClass& declaredClass); void check(const ScopePtr& scope, const AstStatDeclareFunction& declaredFunction); @@ -336,7 +337,7 @@ struct TypeChecker // Note: `scope` must be a fresh scope. std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -383,6 +384,8 @@ struct TypeChecker std::function prepareModuleScope; InternalErrorReporter* iceHandler; + UnifierSharedState unifierState; + public: const TypeId nilType; const TypeId numberType; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d4e4e4913..9611e881f 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -405,7 +405,7 @@ const std::string* getName(TypeId type); // Checks whether a union contains all types of another union. bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); -// Checks if a type conains generic type binders +// Checks if a type contains generic type binders bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders @@ -540,4 +540,11 @@ UnionTypeVarIterator end(const UnionTypeVar* utv); using TypeIdPredicate = std::function(TypeId)>; std::vector filterMap(TypeId type, TypeIdPredicate predicate); +void attachTag(TypeId ty, const std::string& tagName); +void attachTag(Property& prop, const std::string& tagName); + +bool hasTag(TypeId ty, const std::string& tagName); +bool hasTag(const Property& prop, const std::string& tagName); +bool hasTag(const Tags& tags, const std::string& tagName); // Do not use in new work. + } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 522914b2f..56632e33c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -6,6 +6,7 @@ #include "Luau/TxnLog.h" #include "Luau/TypeInfer.h" #include "Luau/Module.h" // FIXME: For TypeArena. It merits breaking out into its own header. +#include "Luau/UnifierSharedState.h" #include @@ -41,11 +42,14 @@ struct Unifier std::shared_ptr counters_DEPRECATED; - InternalErrorReporter* iceHandler; + UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED = nullptr, + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, + UnifierCounters* counters = nullptr); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, UnifierCounters* counters = nullptr); // Test whether the two type vars unify. Never commits the result. @@ -69,7 +73,8 @@ struct Unifier void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); - TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + void cacheResult(TypeId superTy, TypeId subTy); public: void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); @@ -101,8 +106,9 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); - DenseHashSet tempSeenTy{nullptr}; - DenseHashSet tempSeenTp{nullptr}; + // Remove with FFlagLuauCacheUnifyTableResults + DenseHashSet tempSeenTy_DEPRECATED{nullptr}; + DenseHashSet tempSeenTp_DEPRECATED{nullptr}; }; } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h new file mode 100644 index 000000000..f252a004b --- /dev/null +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -0,0 +1,44 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/DenseHash.h" +#include "Luau/TypeVar.h" +#include "Luau/TypePack.h" + +#include + +namespace Luau +{ +struct InternalErrorReporter; + +struct TypeIdPairHash +{ + size_t hashOne(Luau::TypeId key) const + { + return (uintptr_t(key) >> 4) ^ (uintptr_t(key) >> 9); + } + + size_t operator()(const std::pair& x) const + { + return hashOne(x.first) ^ (hashOne(x.second) << 1); + } +}; + +struct UnifierSharedState +{ + UnifierSharedState(InternalErrorReporter* iceHandler) + : iceHandler(iceHandler) + { + } + + InternalErrorReporter* iceHandler; + + DenseHashSet seenAny{nullptr}; + DenseHashMap skipCacheForType{nullptr}; + DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; + + DenseHashSet tempSeenTy{nullptr}; + DenseHashSet tempSeenTp{nullptr}; +}; + +} // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index df0bd4205..a866655c9 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -1,9 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/TypeVar.h" #include "Luau/TypePack.h" +LUAU_FASTFLAG(LuauCacheUnifyTableResults) + namespace Luau { @@ -32,17 +35,33 @@ inline bool hasSeen(std::unordered_set& seen, const void* tv) return !seen.insert(ttv).second; } +inline bool hasSeen(DenseHashSet& seen, const void* tv) +{ + void* ttv = const_cast(tv); + + if (seen.contains(ttv)) + return true; + + seen.insert(ttv); + return false; +} + inline void unsee(std::unordered_set& seen, const void* tv) { void* ttv = const_cast(tv); seen.erase(ttv); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen); +inline void unsee(DenseHashSet& seen, const void* tv) +{ + // When DenseHashSet is used for 'visitOnce', where don't forget visited elements +} + +template +void visit(TypePackId tp, F& f, Set& seen); -template -void visit(TypeId ty, F& f, std::unordered_set& seen) +template +void visit(TypeId ty, F& f, Set& seen) { if (visit_detail::hasSeen(seen, ty)) { @@ -79,15 +98,23 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) else if (auto ttv = get(ty)) { + // Some visitors want to see bound tables, that's why we visit the original type if (apply(ty, *ttv, seen, f)) { - for (auto& [_name, prop] : ttv->props) - visit(prop.type, f, seen); - - if (ttv->indexer) + if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo) { - visit(ttv->indexer->indexType, f, seen); - visit(ttv->indexer->indexResultType, f, seen); + visit(*ttv->boundTo, f, seen); + } + else + { + for (auto& [_name, prop] : ttv->props) + visit(prop.type, f, seen); + + if (ttv->indexer) + { + visit(ttv->indexer->indexType, f, seen); + visit(ttv->indexer->indexResultType, f, seen); + } } } } @@ -140,8 +167,8 @@ void visit(TypeId ty, F& f, std::unordered_set& seen) visit_detail::unsee(seen, ty); } -template -void visit(TypePackId tp, F& f, std::unordered_set& seen) +template +void visit(TypePackId tp, F& f, Set& seen) { if (visit_detail::hasSeen(seen, tp)) { @@ -182,6 +209,7 @@ void visit(TypePackId tp, F& f, std::unordered_set& seen) visit_detail::unsee(seen, tp); } + } // namespace visit_detail template @@ -197,4 +225,11 @@ void visitTypeVar(TID ty, F& f) visit_detail::visit(ty, f, seen); } +template +void visitTypeVarOnce(TID ty, F& f, DenseHashSet& seen) +{ + seen.clear(); + visit_detail::visit(ty, f, seen); +} + } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 235abf36f..3c43c8086 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -196,7 +196,8 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { InternalErrorReporter iceReporter; - Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, &iceReporter); + UnifierSharedState unifierState(&iceReporter); + Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); unifier.tryUnify(expectedType, actualType); @@ -1460,7 +1461,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M result.erase(std::string(stringKey->value.data, stringKey->value.size)); } - // If we know for sure that a key is being written, do not offer general epxression suggestions + // If we know for sure that a key is being written, do not offer general expression suggestions if (!key) autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position, result); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 3b0c21638..f6f2363c6 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -106,18 +106,6 @@ void attachMagicFunction(TypeId ty, MagicFunction fn) LUAU_ASSERT(!"Got a non functional type"); } -void attachFunctionTag(TypeId ty, std::string tag) -{ - if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(std::move(tag)); - } - else - { - LUAU_ASSERT(!"Got a non functional type"); - } -} - Property makeProperty(TypeId ty, std::optional documentationSymbol) { return { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 92fbffc80..04d91444a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -158,21 +158,20 @@ struct ErrorConverter std::string operator()(const Luau::CountMismatch& e) const { + const std::string expectedS = e.expected == 1 ? "" : "s"; + const std::string actualS = e.actual == 1 ? "" : "s"; + const std::string actualVerb = e.actual == 1 ? "is" : "are"; + switch (e.context) { case CountMismatch::Return: - { - const std::string expectedS = e.expected == 1 ? "" : "s"; - const std::string actualS = e.actual == 1 ? "is" : "are"; - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + actualS + - " returned here"; - } + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + + std::to_string(e.actual) + " " + actualVerb + " returned here"; case CountMismatch::Result: - if (e.expected > e.actual) - return "Function returns " + std::to_string(e.expected) + " values but there are only " + std::to_string(e.expected) + - " values to unpack them into."; - else - return "Function only returns " + std::to_string(e.expected) + " values. " + std::to_string(e.actual) + " are required here"; + // It is alright if right hand side produces more values than the + // left hand side accepts. In this context consider only the opposite case. + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index b2529840b..5e7af50c5 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -23,6 +23,7 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -248,7 +249,7 @@ struct RequireCycle // Note that this is O(V^2) for a fully connected graph and produces O(V) paths of length O(V) // However, when the graph is acyclic, this is O(V), as well as when only the first cycle is needed (stopAtFirst=true) std::vector getRequireCycles( - const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) + const FileResolver* resolver, const std::unordered_map& sourceNodes, const SourceNode* start, bool stopAtFirst = false) { std::vector result; @@ -282,9 +283,9 @@ std::vector getRequireCycles( if (top == start) { for (const SourceNode* node : path) - cycle.push_back(node->name); + cycle.push_back(resolver->getHumanReadableModuleName(node->name)); - cycle.push_back(top->name); + cycle.push_back(resolver->getHumanReadableModuleName(top->name)); break; } } @@ -404,7 +405,7 @@ CheckResult Frontend::check(const ModuleName& name) // however, for now getRequireCycles isn't expensive in practice on the cases we care about, and long term // all correct programs must be acyclic so this code triggers rarely if (cycleDetected) - requireCycles = getRequireCycles(sourceNodes, &sourceNode, mode == Mode::NoCheck); + requireCycles = getRequireCycles(fileResolver, sourceNodes, &sourceNode, mode == Mode::NoCheck); // This is used by the type checker to replace the resulting type of cyclic modules with any sourceModule.cyclic = !requireCycles.empty(); @@ -458,6 +459,8 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); + if (FFlag::LuauClearScopes) + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index df6be767b..2fd958965 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) namespace Luau { @@ -299,6 +300,14 @@ void TypeCloner::operator()(const FunctionTypeVar& t) void TypeCloner::operator()(const TableTypeVar& t) { + // If table is now bound to another one, we ignore the content of the original + if (FFlag::LuauCloneBoundTables && t.boundTo) + { + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + seenTypes[typeId] = boundTo; + return; + } + TypeId result = dest.addType(TableTypeVar{}); TableTypeVar* ttv = getMutable(result); LUAU_ASSERT(ttv != nullptr); @@ -321,8 +330,11 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + if (!FFlag::LuauCloneBoundTables) + { + if (t.boundTo) + ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + } for (TypeId& arg : ttv->instantiatedTypeParams) arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); @@ -335,7 +347,7 @@ void TypeCloner::operator()(const TableTypeVar& t) if (ttv->state == TableState::Free) { - if (!t.boundTo) + if (FFlag::LuauCloneBoundTables || !t.boundTo) { if (encounteredFreeType) *encounteredFreeType = true; diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp new file mode 100644 index 000000000..bf6d81aa6 --- /dev/null +++ b/Analysis/src/Quantify.cpp @@ -0,0 +1,90 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Quantify.h" + +#include "Luau/VisitTypeVar.h" + +namespace Luau +{ + +struct Quantifier +{ + ModulePtr module; + TypeLevel level; + std::vector generics; + std::vector genericPacks; + + Quantifier(ModulePtr module, TypeLevel level) + : module(module) + , level(level) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + if (!level.subsumes(ftv.level)) + return false; + + *asMutable(ty) = GenericTypeVar{level}; + generics.push_back(ty); + + return false; + } + + template + bool operator()(TypeId ty, const T& t) + { + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.state == TableState::Sealed || ttv.state == TableState::Generic) + return false; + if (!level.subsumes(ttv.level)) + return false; + + if (ttv.state == TableState::Free) + ttv.state = TableState::Generic; + else if (ttv.state == TableState::Unsealed) + ttv.state = TableState::Sealed; + + ttv.level = level; + + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + if (!level.subsumes(ftp.level)) + return false; + + *asMutable(tp) = GenericTypePack{level}; + genericPacks.push_back(tp); + return true; + } +}; + +void quantify(ModulePtr module, TypeId ty, TypeLevel level) +{ + Quantifier q{std::move(module), level}; + visitTypeVar(ty, q); + + FunctionTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->generics = q.generics; + ftv->genericPacks = q.genericPacks; +} + +} // namespace Luau diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index ad4d5ef43..95910b562 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -171,7 +171,7 @@ struct RequireTracerOld : AstVisitor result.exprs[call] = {fileResolver->concat(*rootName, v)}; - // 'WaitForChild' can be used on modules that are not awailable at the typecheck time, but will be awailable at runtime + // 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime // If we fail to find such module, we will not report an UnknownRequire error if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") result.exprs[call].optional = true; @@ -182,7 +182,7 @@ struct RequireTracerOld : AstVisitor struct RequireTracer : AstVisitor { - RequireTracer(RequireTraceResult& result, FileResolver * fileResolver, const ModuleName& currentModuleName) + RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) : result(result) , fileResolver(fileResolver) , currentModuleName(currentModuleName) @@ -260,7 +260,7 @@ struct RequireTracer : AstVisitor // seed worklist with require arguments work.reserve(requires.size()); - for (AstExprCall* require: requires) + for (AstExprCall* require : requires) work.push_back(require->args.data[0]); // push all dependent expressions to the work stack; note that the vector is modified during traversal diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5651af7e9..cd8180dba 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauExtraNilRecovery) LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -159,15 +158,6 @@ struct StringifierState seen.erase(iter); } - static std::string generateName(size_t i) - { - std::string n; - n = char('a' + i % 26); - if (i >= 26) - n += std::to_string(i / 26); - return n; - } - std::string getName(TypeId ty) { const size_t s = result.nameMap.typeVars.size(); @@ -584,8 +574,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : &uv) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); if (isNil(el)) { @@ -649,8 +638,7 @@ struct TypeVarStringifier std::vector results = {}; for (auto el : uv.parts) { - if (FFlag::LuauExtraNilRecovery || FFlag::LuauAddMissingFollow) - el = follow(el); + el = follow(el); std::string saved = std::move(state.result.name); @@ -1204,4 +1192,13 @@ void dump(TypePackId ty) printf("%s\n", toString(ty, opts).c_str()); } +std::string generateName(size_t i) +{ + std::string n; + n = char('a' + i % 26); + if (i >= 26) + n += std::to_string(i / 26); + return n; +} + } // namespace Luau diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index 2d3563842..dba694be0 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -298,8 +298,15 @@ struct ArcCollector : public AstVisitor struct ContainsFunctionCall : public AstVisitor { + bool alsoReturn = false; bool result = false; + ContainsFunctionCall() = default; + explicit ContainsFunctionCall(bool alsoReturn) + : alsoReturn(alsoReturn) + { + } + bool visit(AstExpr*) override { return !result; // short circuit if result is true @@ -318,6 +325,17 @@ struct ContainsFunctionCall : public AstVisitor return false; } + bool visit(AstStatReturn* stat) override + { + if (alsoReturn) + { + result = true; + return false; + } + else + return AstVisitor::visit(stat); + } + bool visit(AstExprFunction*) override { return false; @@ -479,6 +497,13 @@ bool containsFunctionCall(const AstStat& stat) return cfc.result; } +bool containsFunctionCallOrReturn(const AstStat& stat) +{ + detail::ContainsFunctionCall cfc{true}; + const_cast(stat).visit(&cfc); + return cfc.result; +} + bool isFunction(const AstStat& stat) { return stat.is() || stat.is(); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 702d0ca2c..383bb050d 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -5,6 +5,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false) + namespace Luau { @@ -33,6 +35,12 @@ void TxnLog::rollback() for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) std::swap(it->first->boundTo, it->second); + + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); + sharedSeen->resize(originalSeenSize); + } } void TxnLog::concat(TxnLog rhs) @@ -46,27 +54,44 @@ void TxnLog::concat(TxnLog rhs) tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); rhs.tableChanges.clear(); - seen.swap(rhs.seen); - rhs.seen.clear(); + if (!FFlag::LuauShareTxnSeen) + { + ownedSeen.swap(rhs.ownedSeen); + rhs.ownedSeen.clear(); + } } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (seen.end() != std::find(seen.begin(), seen.end(), sortedPair)); + if (FFlag::LuauShareTxnSeen) + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); + else + return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair)); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - seen.push_back(sortedPair); + if (FFlag::LuauShareTxnSeen) + sharedSeen->push_back(sortedPair); + else + ownedSeen.push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - LUAU_ASSERT(sortedPair == seen.back()); - seen.pop_back(); + if (FFlag::LuauShareTxnSeen) + { + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); + } + else + { + LUAU_ASSERT(sortedPair == ownedSeen.back()); + ownedSeen.pop_back(); + } } } // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 266c19865..49f8e0cac 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -6,6 +6,7 @@ #include "Luau/Parser.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" +#include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" @@ -33,14 +34,31 @@ static char* allocateString(Luau::Allocator& allocator, const char* format, Data return result; } +using SyntheticNames = std::unordered_map; + namespace Luau { + +static const char* getName(Allocator* allocator, SyntheticNames* syntheticNames, const Unifiable::Generic& gen) +{ + size_t s = syntheticNames->size(); + char*& n = (*syntheticNames)[&gen]; + if (!n) + { + std::string str = gen.explicitName ? gen.name : generateName(s); + n = static_cast(allocator->allocate(str.size() + 1)); + strcpy(n, str.c_str()); + } + + return n; +} + class TypeRehydrationVisitor { - mutable std::map seen; - mutable int count = 0; + std::map seen; + int count = 0; - bool hasSeen(const void* tv) const + bool hasSeen(const void* tv) { void* ttv = const_cast(tv); auto it = seen.find(ttv); @@ -52,15 +70,16 @@ class TypeRehydrationVisitor } public: - TypeRehydrationVisitor(Allocator* alloc, const TypeRehydrationOptions& options = TypeRehydrationOptions()) + TypeRehydrationVisitor(Allocator* alloc, SyntheticNames* syntheticNames, const TypeRehydrationOptions& options = TypeRehydrationOptions()) : allocator(alloc) + , syntheticNames(syntheticNames) , options(options) { } - AstTypePack* rehydrate(TypePackId tp) const; + AstTypePack* rehydrate(TypePackId tp); - AstType* operator()(const PrimitiveTypeVar& ptv) const + AstType* operator()(const PrimitiveTypeVar& ptv) { switch (ptv.type) { @@ -78,11 +97,11 @@ class TypeRehydrationVisitor return nullptr; } } - AstType* operator()(const AnyTypeVar&) const + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); } - AstType* operator()(const TableTypeVar& ttv) const + AstType* operator()(const TableTypeVar& ttv) { RecursionCounter counter(&count); @@ -144,12 +163,12 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), props, indexer); } - AstType* operator()(const MetatableTypeVar& mtv) const + AstType* operator()(const MetatableTypeVar& mtv) { return Luau::visit(*this, mtv.table->ty); } - AstType* operator()(const ClassTypeVar& ctv) const + AstType* operator()(const ClassTypeVar& ctv) { RecursionCounter counter(&count); @@ -176,7 +195,7 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), props); } - AstType* operator()(const FunctionTypeVar& ftv) const + AstType* operator()(const FunctionTypeVar& ftv) { RecursionCounter counter(&count); @@ -253,10 +272,12 @@ class TypeRehydrationVisitor size_t i = 0; for (const auto& el : ftv.argNames) { + std::optional* arg = &argNames.data[i++]; + if (el) - argNames.data[i++] = {AstName(el->name.c_str()), el->location}; + new (arg) std::optional(AstArgumentName(AstName(el->name.c_str()), el->location)); else - argNames.data[i++] = {}; + new (arg) std::optional(); } AstArray returnTypes; @@ -290,23 +311,23 @@ class TypeRehydrationVisitor return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); } - AstType* operator()(const Unifiable::Error&) const + AstType* operator()(const Unifiable::Error&) { return allocator->alloc(Location(), std::nullopt, AstName("Unifiable")); } - AstType* operator()(const GenericTypeVar& gtv) const + AstType* operator()(const GenericTypeVar& gtv) { - return allocator->alloc(Location(), std::nullopt, AstName(gtv.name.c_str())); + return allocator->alloc(Location(), std::nullopt, AstName(getName(allocator, syntheticNames, gtv))); } - AstType* operator()(const Unifiable::Bound& bound) const + AstType* operator()(const Unifiable::Bound& bound) { return Luau::visit(*this, bound.boundTo->ty); } - AstType* operator()(Unifiable::Free ftv) const + AstType* operator()(const FreeTypeVar& ftv) { return allocator->alloc(Location(), std::nullopt, AstName("free")); } - AstType* operator()(const UnionTypeVar& uv) const + AstType* operator()(const UnionTypeVar& uv) { AstArray unionTypes; unionTypes.size = uv.options.size(); @@ -317,7 +338,7 @@ class TypeRehydrationVisitor } return allocator->alloc(Location(), unionTypes); } - AstType* operator()(const IntersectionTypeVar& uv) const + AstType* operator()(const IntersectionTypeVar& uv) { AstArray intersectionTypes; intersectionTypes.size = uv.parts.size(); @@ -328,23 +349,28 @@ class TypeRehydrationVisitor } return allocator->alloc(Location(), intersectionTypes); } - AstType* operator()(const LazyTypeVar& ltv) const + AstType* operator()(const LazyTypeVar& ltv) { return allocator->alloc(Location(), std::nullopt, AstName("")); } private: Allocator* allocator; + SyntheticNames* syntheticNames; const TypeRehydrationOptions& options; }; class TypePackRehydrationVisitor { public: - TypePackRehydrationVisitor(Allocator* allocator, const TypeRehydrationVisitor& typeVisitor) + TypePackRehydrationVisitor(Allocator* allocator, SyntheticNames* syntheticNames, TypeRehydrationVisitor* typeVisitor) : allocator(allocator) + , syntheticNames(syntheticNames) , typeVisitor(typeVisitor) { + LUAU_ASSERT(allocator); + LUAU_ASSERT(syntheticNames); + LUAU_ASSERT(typeVisitor); } AstTypePack* operator()(const BoundTypePack& btp) const @@ -359,7 +385,7 @@ class TypePackRehydrationVisitor head.data = static_cast(allocator->allocate(sizeof(AstType*) * tp.head.size())); for (size_t i = 0; i < tp.head.size(); i++) - head.data[i] = Luau::visit(typeVisitor, tp.head[i]->ty); + head.data[i] = Luau::visit(*typeVisitor, tp.head[i]->ty); AstTypePack* tail = nullptr; @@ -371,12 +397,12 @@ class TypePackRehydrationVisitor AstTypePack* operator()(const VariadicTypePack& vtp) const { - return allocator->alloc(Location(), Luau::visit(typeVisitor, vtp.ty->ty)); + return allocator->alloc(Location(), Luau::visit(*typeVisitor, vtp.ty->ty)); } AstTypePack* operator()(const GenericTypePack& gtp) const { - return allocator->alloc(Location(), AstName(gtp.name.c_str())); + return allocator->alloc(Location(), AstName(getName(allocator, syntheticNames, gtp))); } AstTypePack* operator()(const FreeTypePack& gtp) const @@ -391,12 +417,13 @@ class TypePackRehydrationVisitor private: Allocator* allocator; - const TypeRehydrationVisitor& typeVisitor; + SyntheticNames* syntheticNames; + TypeRehydrationVisitor* typeVisitor; }; -AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) const +AstTypePack* TypeRehydrationVisitor::rehydrate(TypePackId tp) { - TypePackRehydrationVisitor tprv(allocator, *this); + TypePackRehydrationVisitor tprv(allocator, syntheticNames, this); return Luau::visit(tprv, tp->ty); } @@ -431,7 +458,7 @@ class TypeAttacher : public AstVisitor { if (!type) return nullptr; - return Luau::visit(TypeRehydrationVisitor(allocator), (*type)->ty); + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), (*type)->ty); } AstArray typeAstPack(TypePackId type) @@ -443,7 +470,7 @@ class TypeAttacher : public AstVisitor result.data = static_cast(allocator->allocate(sizeof(AstType*) * v.size())); for (size_t i = 0; i < v.size(); ++i) { - result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator), v[i]->ty); + result.data[i] = Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames), v[i]->ty); } return result; } @@ -495,7 +522,7 @@ class TypeAttacher : public AstVisitor { if (FFlag::LuauTypeAliasPacks) { - variadicAnnotation = TypeRehydrationVisitor(allocator).rehydrate(*tail); + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); } else { @@ -515,6 +542,7 @@ class TypeAttacher : public AstVisitor private: Module& module; Allocator* allocator; + SyntheticNames syntheticNames; }; void attachTypeData(SourceModule& source, Module& result) @@ -525,7 +553,8 @@ void attachTypeData(SourceModule& source, Module& result) AstType* rehydrateAnnotation(TypeId type, Allocator* allocator, const TypeRehydrationOptions& options) { - return Luau::visit(TypeRehydrationVisitor(allocator, options), type->ty); + SyntheticNames syntheticNames; + return Luau::visit(TypeRehydrationVisitor(allocator, &syntheticNames, options), type->ty); } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 3a1fdfff5..38e2e5270 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" #include "Luau/Parser.h" +#include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/Substitution.h" @@ -33,18 +34,16 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauExtraNilRecovery, false) -LUAU_FASTFLAGVARIABLE(LuauMissingUnionPropertyError, false) LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauInferFunctionArgsFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) +LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace) LUAU_FASTFLAG(LuauTypeAliasPacks) @@ -215,6 +214,7 @@ static bool isMetamethod(const Name& name) TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) : resolver(resolver) , iceHandler(iceHandler) + , unifierState(iceHandler) , nilType(singletonTypes.nilType) , numberType(singletonTypes.numberType) , stringType(singletonTypes.stringType) @@ -370,13 +370,18 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) return; } + int subLevel = 0; + std::vector sorted(block.body.data, block.body.data + block.body.size); toposort(sorted); for (const auto& stat : sorted) { if (const auto& typealias = stat->as()) - check(scope, *typealias, true); + { + check(scope, *typealias, subLevel, true); + ++subLevel; + } } auto protoIter = sorted.begin(); @@ -399,8 +404,6 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) } }; - int subLevel = 0; - while (protoIter != sorted.end()) { // protoIter walks forward @@ -416,7 +419,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // ``` // These both call each other, so `f` will be ordered before `g`, so the call to `g` // is typechecked before `g` has had its body checked. For this reason, there's three - // types for each functuion: before its body is checked, during checking its body, + // types for each function: before its body is checked, during checking its body, // and after its body is checked. // // We currently treat the before-type and the during-type as the same, @@ -433,7 +436,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (containsFunctionCall(**protoIter)) + if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter)) { while (checkIter != protoIter) { @@ -1080,7 +1083,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); // If in nonstrict mode and allowing redefinition of global function, restore the previous definition type - // in case this function has a differing signature. The signature discrepency will be caught in checkBlock. + // in case this function has a differing signature. The signature discrepancy will be caught in checkBlock. if (previouslyDefined) globalBindings[name] = oldBinding; else @@ -1161,7 +1164,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; } -void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, bool forwardDeclare) +void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) { // This function should be called at most twice for each type alias. // Once with forwardDeclare, and once without. @@ -1189,11 +1192,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = childScope(scope, typealias.location); + ScopePtr aliasScope = + FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); if (FFlag::LuauTypeAliasPacks) { - auto [generics, genericPacks] = createGenericTypes(aliasScope, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); FreeTypeVar* ftv = getMutable(ty); @@ -1418,7 +1422,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo { ScopePtr funScope = childFunctionScope(scope, global.location); - auto [generics, genericPacks] = createGenericTypes(funScope, global, global.generics, global.genericPacks); + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); @@ -1610,25 +1614,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = resolveLValue(scope, *lvalue)) return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - if (FFlag::LuauExtraNilRecovery) - lhsType = stripFromNilAndReport(lhsType, expr.expr->location); + lhsType = stripFromNilAndReport(lhsType, expr.expr->location); if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - if (!FFlag::LuauMissingUnionPropertyError) - reportError(expr.indexLocation, UnknownProperty{lhsType, expr.index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(lhsType)) - { - if (std::optional ty = getIndexTypeFromType(scope, *strippedUnion, name, expr.location, false)) - return {*ty}; - } - } - return {errorType}; } @@ -1694,61 +1684,37 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (const UnionTypeVar* utv = get(type)) { - if (FFlag::LuauMissingUnionPropertyError) - { - std::vector goodOptions; - std::vector badOptions; - - for (TypeId t : utv) - { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + std::vector goodOptions; + std::vector badOptions; - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - goodOptions.push_back(*ty); - else - badOptions.push_back(t); - } - - if (!badOptions.empty()) - { - if (addErrors) - { - if (goodOptions.empty()) - reportError(location, UnknownProperty{type, name}); - else - reportError(location, MissingUnionProperty{type, badOptions, name}); - } - return std::nullopt; - } - - std::vector result = reduceUnion(goodOptions); - - if (result.size() == 1) - return result[0]; + for (TypeId t : utv) + { + RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - return addType(UnionTypeVar{std::move(result)}); + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) + goodOptions.push_back(*ty); + else + badOptions.push_back(t); } - else - { - std::vector options; - for (TypeId t : utv->options) + if (!badOptions.empty()) + { + if (addErrors) { - RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); - - if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) - options.push_back(*ty); + if (goodOptions.empty()) + reportError(location, UnknownProperty{type, name}); else - return std::nullopt; + reportError(location, MissingUnionProperty{type, badOptions, name}); } + return std::nullopt; + } - std::vector result = reduceUnion(options); + std::vector result = reduceUnion(goodOptions); - if (result.size() == 1) - return result[0]; + if (result.size() == 1) + return result[0]; - return addType(UnionTypeVar{std::move(result)}); - } + return addType(UnionTypeVar{std::move(result)}); } else if (const IntersectionTypeVar* itv = get(type)) { @@ -1765,7 +1731,7 @@ std::optional TypeChecker::getIndexTypeFromType( // If no parts of the intersection had the property we looked up for, it never existed at all. if (parts.empty()) { - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; } @@ -1779,7 +1745,7 @@ std::optional TypeChecker::getIndexTypeFromType( return addType(IntersectionTypeVar{result}); } - if (FFlag::LuauMissingUnionPropertyError && addErrors) + if (addErrors) reportError(location, UnknownProperty{type, name}); return std::nullopt; @@ -2062,8 +2028,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn case AstExprUnary::Len: tablify(operandType); - if (FFlag::LuauExtraNilRecovery) - operandType = stripFromNilAndReport(operandType, expr.location); + operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) return {errorType}; @@ -2635,8 +2600,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Name name = expr.index.value; - if (FFlag::LuauExtraNilRecovery) - lhs = stripFromNilAndReport(lhs, expr.expr->location); + lhs = stripFromNilAndReport(lhs, expr.expr->location); if (TableTypeVar* lhsTable = getMutableTableType(lhs)) { @@ -2710,8 +2674,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); - if (FFlag::LuauExtraNilRecovery) - exprType = stripFromNilAndReport(exprType, expr.expr->location); + exprType = stripFromNilAndReport(exprType, expr.expr->location); TypeId indexType = checkExpr(scope, *expr.index).type; @@ -2738,10 +2701,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { - if (FFlag::LuauExtraNilRecovery) - reportError(TypeError{expr.expr->location, NotATable{exprType}}); - else - reportError(TypeError{expr.location, NotATable{exprType}}); + reportError(TypeError{expr.expr->location, NotATable{exprType}}); return std::pair(errorType, nullptr); } @@ -2910,7 +2870,7 @@ std::pair TypeChecker::checkFunctionSignature( if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funScope, expr, expr.generics, expr.genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); } TypePackId retPack; @@ -3016,9 +2976,6 @@ std::pair TypeChecker::checkFunctionSignature( if (expectedArgsCurr != expectedArgsEnd) { argType = *expectedArgsCurr; - - if (!FFlag::LuauInferFunctionArgsFix) - ++expectedArgsCurr; } else if (auto expectedArgsTail = expectedArgsCurr.tail()) { @@ -3034,7 +2991,7 @@ std::pair TypeChecker::checkFunctionSignature( funScope->bindings[local] = {argType, local->location}; argTypes.push_back(argType); - if (FFlag::LuauInferFunctionArgsFix && expectedArgsCurr != expectedArgsEnd) + if (expectedArgsCurr != expectedArgsEnd) ++expectedArgsCurr; } @@ -3170,7 +3127,7 @@ void TypeChecker::checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) { /* Important terminology refresher: - * A function requires paramaters. + * A function requires parameters. * To call a function, you supply arguments. */ TypePackIterator argIter = begin(argPack); @@ -3402,8 +3359,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A if (!FFlag::LuauRankNTypes) instantiate(scope, selfType, expr.func->location); - if (FFlag::LuauExtraNilRecovery) - selfType = stripFromNilAndReport(selfType, expr.func->location); + selfType = stripFromNilAndReport(selfType, expr.func->location); if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) { @@ -3412,34 +3368,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - if (!FFlag::LuauMissingUnionPropertyError) - reportError(indexExpr->indexLocation, UnknownProperty{selfType, indexExpr->index.value}); - - if (!FFlag::LuauExtraNilRecovery) - { - // Try to recover using a union without 'nil' options - if (std::optional strippedUnion = tryStripUnionFromNil(selfType)) - { - if (std::optional propTy = getIndexTypeFromType(scope, *strippedUnion, indexExpr->index.value, expr.location, false)) - { - selfType = *strippedUnion; - - functionType = *propTy; - actualFunctionType = instantiate(scope, functionType, expr.func->location); - } - } - - if (!actualFunctionType) - { - functionType = errorType; - actualFunctionType = errorType; - } - } - else - { - functionType = errorType; - actualFunctionType = errorType; - } + functionType = errorType; + actualFunctionType = errorType; } } else @@ -3555,8 +3485,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& errors) { - if (FFlag::LuauExtraNilRecovery) - fn = stripFromNilAndReport(fn, expr.func->location); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) { @@ -4283,6 +4212,12 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) return ty; + if (FFlag::LuauQuantifyInPlace2) + { + Luau::quantify(currentModule, ty, scope->level); + return ty; + } + quantification.level = scope->level; quantification.generics.clear(); quantification.genericPacks.clear(); @@ -4491,12 +4426,12 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; } Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, iceHandler}; + return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -4753,7 +4688,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (FFlag::LuauGenericFunctions) { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, annotation, func->generics, func->genericPacks); + std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); } // TODO: better error message CLI-39912 @@ -5041,10 +4976,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) + const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); + const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; + std::vector generics; for (const AstName& generic : genericNames) { @@ -5063,12 +5000,12 @@ std::pair, std::vector> TypeChecker::createGener { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) - cached = addType(GenericTypeVar{scope->level, n}); + cached = addType(GenericTypeVar{level, n}); g = cached; } else { - g = addType(Unifiable::Generic{scope->level, n}); + g = addType(Unifiable::Generic{level, n}); } generics.push_back(g); @@ -5093,12 +5030,12 @@ std::pair, std::vector> TypeChecker::createGener { TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) - cached = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + cached = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); g = cached; } else { - g = addTypePack(TypePackVar{Unifiable::Generic{scope->level, n}}); + g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } genericPacks.push_back(g); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index e963fc74e..e82f7519d 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -22,6 +22,7 @@ LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) namespace Luau { @@ -217,8 +218,7 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); - primitiveType && primitiveType->metatable) + else if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) { LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); return primitiveType->metatable; @@ -1490,4 +1490,86 @@ std::vector filterMap(TypeId type, TypeIdPredicate predicate) return {}; } +static Tags* getTags(TypeId ty) +{ + ty = follow(ty); + + if (auto ftv = getMutable(ty)) + return &ftv->tags; + else if (auto ttv = getMutable(ty)) + return &ttv->tags; + else if (auto ctv = getMutable(ty)) + return &ctv->tags; + + return nullptr; +} + +void attachTag(TypeId ty, const std::string& tagName) +{ + if (!FFlag::LuauRefactorTagging) + { + if (auto ftv = getMutable(ty)) + { + ftv->tags.emplace_back(tagName); + } + else + { + LUAU_ASSERT(!"Got a non functional type"); + } + } + else + { + if (auto tags = getTags(ty)) + tags->push_back(tagName); + else + LUAU_ASSERT(!"This TypeId does not support tags"); + } +} + +void attachTag(Property& prop, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + + prop.tags.push_back(tagName); +} + +// We would ideally not expose this because it could cause a footgun. +// If the Base class has a tag and you ask if Derived has that tag, it would return false. +// Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. +bool hasTag(const Tags& tags, const std::string& tagName) +{ + LUAU_ASSERT(FFlag::LuauRefactorTagging); + return std::find(tags.begin(), tags.end(), tagName) != tags.end(); +} + +bool hasTag(TypeId ty, const std::string& tagName) +{ + ty = follow(ty); + + // We special case classes because getTags only returns a pointer to one vector of tags. + // But classes has multiple vector of tags, represented throughout the hierarchy. + if (auto ctv = get(ty)) + { + while (ctv) + { + if (hasTag(ctv->tags, tagName)) + return true; + else if (!ctv->parent) + return false; + + ctv = get(*ctv->parent); + LUAU_ASSERT(ctv); + } + } + else if (auto tags = getTags(ty)) + return hasTag(*tags, tagName); + + return false; +} + +bool hasTag(const Property& prop, const std::string& tagName) +{ + return hasTag(prop.tags, tagName); +} + } // namespace Luau diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 117cbc289..2539650a4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -7,6 +7,7 @@ #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" +#include "Luau/VisitTypeVar.h" #include @@ -22,9 +23,99 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) +LUAU_FASTFLAG(LuauShareTxnSeen); +LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) namespace Luau { +struct SkipCacheForType +{ + SkipCacheForType(const DenseHashMap& skipCacheForType) + : skipCacheForType(skipCacheForType) + { + } + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + bool operator()(TypeId ty, const FreeTypeVar& ftv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const BoundTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const GenericTypeVar& btv) + { + result = true; + return false; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + TableTypeVar& ttv = *getMutable(ty); + + if (ttv.boundTo) + { + result = true; + return false; + } + + if (ttv.state != TableState::Sealed) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypeId ty, const T& t) + { + const bool* prev = skipCacheForType.find(ty); + + if (prev && *prev) + { + result = true; + return false; + } + + return true; + } + + template + bool operator()(TypePackId, const T&) + { + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const BoundTypePack& ftp) + { + result = true; + return false; + } + + bool operator()(TypePackId tp, const GenericTypePack& ftp) + { + result = true; + return false; + } + + const DenseHashMap& skipCacheForType; + bool result = false; +}; static std::optional hasUnificationTooComplex(const ErrorVec& errors) { @@ -39,7 +130,7 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, InternalErrorReporter* iceHandler) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) , globalScope(std::move(globalScope)) @@ -47,24 +138,39 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , variance(variance) , counters(&countersData) , counters_DEPRECATED(std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) +{ + LUAU_ASSERT(sharedState.iceHandler); +} + +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) + : types(types) + , mode(mode) + , globalScope(std::move(globalScope)) + , log(ownedSeen) + , location(location) + , variance(variance) + , counters(counters ? counters : &countersData) + , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& seen, const Location& location, - Variance variance, InternalErrorReporter* iceHandler, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) : types(types) , mode(mode) , globalScope(std::move(globalScope)) - , log(seen) + , log(sharedSeen) , location(location) , variance(variance) , counters(counters ? counters : &countersData) , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) - , iceHandler(iceHandler) + , sharedState(sharedState) { - LUAU_ASSERT(iceHandler); + LUAU_ASSERT(sharedState.iceHandler); } void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -74,7 +180,7 @@ void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool i else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTy, subTy, isFunctionCall, isIntersection); + tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) @@ -206,6 +312,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(subTy, superTy); + bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection; + auto& cache = sharedState.cachedUnify; + + // What if the types are immutable and we proved their relation before + if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + return; + // If we have seen this pair of types before, we are currently recursing into cyclic types. // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. @@ -257,6 +370,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (FFlag::LuauUnionHeuristic) { + bool found = false; + const std::string* subName = getName(subTy); if (subName) { @@ -264,6 +379,21 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) + { + found = true; + startIndex = i; + break; + } + } + } + + if (!found && cacheEnabled) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[i]; + + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) { startIndex = i; break; @@ -311,8 +441,25 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; - for (TypeId type : uv->parts) + size_t startIndex = 0; + + if (cacheEnabled) { + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[i]; + + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + { + startIndex = i; + break; + } + } + } + + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); innerState.tryUnify_(superTy, type, isFunctionCall); @@ -342,8 +489,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool tryUnifyFunctions(superTy, subTy, isFunctionCall); else if (get(superTy) && get(subTy)) + { tryUnifyTables(superTy, subTy, isIntersection); + if (cacheEnabled && errors.empty()) + cacheResult(superTy, subTy); + } + // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. else if (get(superTy)) tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); @@ -364,6 +516,41 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log.popSeen(superTy, subTy); } +void Unifier::cacheResult(TypeId superTy, TypeId subTy) +{ + LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults); + + bool* superTyInfo = sharedState.skipCacheForType.find(superTy); + + if (superTyInfo && *superTyInfo) + return; + + bool* subTyInfo = sharedState.skipCacheForType.find(subTy); + + if (subTyInfo && *subTyInfo) + return; + + auto skipCacheFor = [this](TypeId ty) { + SkipCacheForType visitor{sharedState.skipCacheForType}; + visitTypeVarOnce(ty, visitor, sharedState.seenAny); + + sharedState.skipCacheForType[ty] = visitor.result; + + return visitor.result; + }; + + if (!superTyInfo && skipCacheFor(superTy)) + return; + + if (!subTyInfo && skipCacheFor(subTy)) + return; + + sharedState.cachedUnify.insert({superTy, subTy}); + + if (variance == Invariant) + sharedState.cachedUnify.insert({subTy, superTy}); +} + struct WeirdIter { TypePackId packId; @@ -459,7 +646,7 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall else counters_DEPRECATED->iterationCount = 0; - return tryUnify_(superTp, subTp, isFunctionCall); + tryUnify_(superTp, subTp, isFunctionCall); } /* @@ -650,11 +837,11 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking a return value, we swap these to produce - // the expected error message. + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. size_t expectedSize = size(superTp); size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result || ctx == CountMismatch::Return) + if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); @@ -797,6 +984,40 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) std::vector missingProperties; std::vector extraProperties; + // Optimization: First test that the property sets are compatible without doing any recursive unification + if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free) + { + for (const auto& [propName, superProp] : lt->props) + { + auto subIter = rt->props.find(propName); + if (subIter == rt->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + missingProperties.push_back(propName); + } + + if (!missingProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + return; + } + } + + // And vice versa if we're invariant + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + { + for (const auto& [propName, subProp] : rt->props) + { + auto superIter = lt->props.find(propName); + if (superIter == lt->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + extraProperties.push_back(propName); + } + + if (!extraProperties.empty()) + { + errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + return; + } + } + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -833,9 +1054,10 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) innerState.log.rollback(); } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (rt->state == TableState::Free) { log(rt); @@ -878,11 +1100,13 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) lt->props[name] = clone; } else if (variance == Covariant) - {} + { + } else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? - {} + // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + // TODO: should isOptional(anyType) be true? + { + } else if (lt->state == TableState::Free) { log(lt); @@ -980,10 +1204,10 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{ singletonTypes.nilType, result }});; + return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); } else - return types->addType(UnionTypeVar{{ singletonTypes.nilType, ty }}); + return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) @@ -1247,7 +1471,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. // Otherwise, we would falsely generate an extra-property-error for 's' in this code: // local a: {n: number} & {s: string} = {n=1, s=""} - // When checking agaist the table '{n: number}'. + // When checking against the table '{n: number}'. if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) { // Check for extra properties in the subTy @@ -1697,10 +1921,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { std::vector queue = {ty}; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); + + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP); + } } else { @@ -1721,12 +1955,24 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { std::vector queue; - tempSeenTy.clear(); - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - queueTypePack(queue, tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - Luau::tryUnifyWithAny(queue, *this, tempSeenTy, tempSeenTp, anyTy, any); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); + } + else + { + tempSeenTy_DEPRECATED.clear(); + tempSeenTp_DEPRECATED.clear(); + + queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any); + + Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any); + } } else { @@ -1775,10 +2021,20 @@ void Unifier::occursCheck(TypeId needle, TypeId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTy.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTy.clear(); + + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTy_DEPRECATED.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTy, needle, haystack); + return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) @@ -1851,10 +2107,20 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { std::unordered_set seen_DEPRECATED; - if (FFlag::LuauTypecheckOpts) - tempSeenTp.clear(); + if (FFlag::LuauCacheUnifyTableResults) + { + if (FFlag::LuauTypecheckOpts) + sharedState.tempSeenTp.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTp, needle, haystack); + return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack); + } + else + { + if (FFlag::LuauTypecheckOpts) + tempSeenTp_DEPRECATED.clear(); + + return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack); + } } void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) @@ -1922,7 +2188,10 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.seen, location, variance, iceHandler, counters_DEPRECATED, counters}; + if (FFlag::LuauShareTxnSeen) + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; + else + return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; } bool Unifier::isNonstrictMode() const @@ -1940,12 +2209,12 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId void Unifier::ice(const std::string& message, const Location& location) { - iceHandler->ice(message, location); + sharedState.iceHandler->ice(message, location); } void Unifier::ice(const std::string& message) { - iceHandler->ice(message); + sharedState.iceHandler->ice(message); } } // namespace Luau diff --git a/Ast/include/Luau/TimeTrace.h b/Ast/include/Luau/TimeTrace.h index 641dfd3c3..503eca61d 100644 --- a/Ast/include/Luau/TimeTrace.h +++ b/Ast/include/Luau/TimeTrace.h @@ -194,20 +194,20 @@ LUAU_NOINLINE std::pair createScopeDa } // namespace Luau // Regular scope -#define LUAU_TIMETRACE_SCOPE(name, category) \ +#define LUAU_TIMETRACE_SCOPE(name, category) \ static auto lttScopeStatic = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::Scope lttScope(lttScopeStatic.second, lttScopeStatic.first) // A scope without nested scopes that may be skipped if the time it took is less than the threshold -#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ +#define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) \ static auto lttScopeStaticOptTail = Luau::TimeTrace::createScopeData(name, category); \ Luau::TimeTrace::OptionalTailScope lttScope(lttScopeStaticOptTail.second, lttScopeStaticOptTail.first, microsec) // Extra key/value data can be added to regular scopes -#define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ - if (FFlag::DebugLuauTimeTracing) \ +#define LUAU_TIMETRACE_ARGUMENT(name, value) \ + do \ + { \ + if (FFlag::DebugLuauTimeTracing) \ lttScopeStatic.second.eventArgument(name, value); \ } while (false) @@ -216,8 +216,8 @@ LUAU_NOINLINE std::pair createScopeDa #define LUAU_TIMETRACE_SCOPE(name, category) #define LUAU_TIMETRACE_OPTIONAL_TAIL_SCOPE(name, category, microsec) #define LUAU_TIMETRACE_ARGUMENT(name, value) \ - do \ - { \ + do \ + { \ } while (false) #endif diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 40026d8be..846bc0ba9 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1593,7 +1593,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { Location location = lexer.current().location; - // For a missing type annoation, capture 'space' between last token and the next one + // For a missing type annotation, capture 'space' between last token and the next one location = Location(lexer.previousLocation().end, lexer.current().location.begin); return {reportTypeAnnotationError(location, {}, /*isMissing*/ true, "Expected type, got %s", lexer.current().toString().c_str()), {}}; diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index e6aab20e5..ded50e53e 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -77,7 +77,10 @@ struct GlobalContext // Ideally we would want all ThreadContext destructors to run // But in VS, not all thread_local object instances are destroyed for (ThreadContext* context : threads) - context->flushEvents(); + { + if (!context->events.empty()) + context->flushEvents(); + } if (traceFile) fclose(traceFile); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 6baa21ea6..4968d0800 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -169,7 +169,17 @@ static std::string runCode(lua_State* L, const std::string& source) } else { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(T, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error = str; + } + error += "\nstack backtrace:\n"; error += lua_debugtrace(T); @@ -322,7 +332,17 @@ static bool runFile(const char* name, lua_State* GL) } else { - std::string error = (status == LUA_YIELD) ? "thread yielded unexpectedly" : lua_tostring(L, -1); + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(L, -1)) + { + error = str; + } + error += "\nstacktrace:\n"; error += lua_debugtrace(L); diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 07be2e749..4b03ed1c7 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -208,14 +208,14 @@ enum LuauOpcode LOP_MODK, LOP_POWK, - // AND, OR: perform `and` or `or` operation (selecting first or second register based on whether the first one is truthful) and put the result into target register + // AND, OR: perform `and` or `or` operation (selecting first or second register based on whether the first one is truthy) and put the result into target register // A: target register // B: source register 1 // C: source register 2 LOP_AND, LOP_OR, - // ANDK, ORK: perform `and` or `or` operation (selecting source register or constant based on whether the source register is truthful) and put the result into target register + // ANDK, ORK: perform `and` or `or` operation (selecting source register or constant based on whether the source register is truthy) and put the result into target register // A: target register // B: source register // C: constant table index (0..255) diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 797ee20dc..7750a1d9f 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -718,9 +718,9 @@ struct Compiler } // compile expr to target temp register - // if the expr (or not expr if onlyTruth is false) is truthful, jump via skipJump - // if the expr (or not expr if onlyTruth is false) is falseful, fall through (target isn't guaranteed to be updated in this case) - // if target is omitted, then the jump behavior is the same - skipJump or fallthrough depending on the truthfulness of the expression + // if the expr (or not expr if onlyTruth is false) is truthy, jump via skipJump + // if the expr (or not expr if onlyTruth is false) is falsy, fall through (target isn't guaranteed to be updated in this case) + // if target is omitted, then the jump behavior is the same - skipJump or fallthrough depending on the truthiness of the expression void compileConditionValue(AstExpr* node, const uint8_t* target, std::vector& skipJump, bool onlyTruth) { // Optimization: we don't need to compute constant values @@ -728,7 +728,7 @@ struct Compiler if (cv && cv->type != Constant::Type_Unknown) { - // note that we only need to compute the value if it's truthful; otherwise we cal fall through + // note that we only need to compute the value if it's truthy; otherwise we cal fall through if (cv->isTruthful() == onlyTruth) { if (target) @@ -747,7 +747,7 @@ struct Compiler case AstExprBinary::And: case AstExprBinary::Or: { - // disambiguation: there's 4 cases (we only need truthful or falseful results based on onlyTruth) + // disambiguation: there's 4 cases (we only need truthy or falsy results based on onlyTruth) // onlyTruth = 1: a and b transforms to a ? b : dontcare // onlyTruth = 1: a or b transforms to a ? a : a // onlyTruth = 0: a and b transforms to !a ? a : b @@ -791,8 +791,8 @@ struct Compiler if (target) { // since target is a temp register, we'll initialize it to 1, and then jump if the comparison is true - // if the comparison is false, we'll fallthrough and target will still be 1 but target has unspecified value for falseful results - // when we only care about falseful values instead of truthful values, the process is the same but with flipped conditionals + // if the comparison is false, we'll fallthrough and target will still be 1 but target has unspecified value for falsy results + // when we only care about falsy values instead of truthy values, the process is the same but with flipped conditionals bytecode.emitABC(LOP_LOADB, *target, onlyTruth ? 1 : 0, 0); } diff --git a/Makefile b/Makefile index 0056870b6..7788251d8 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,5 @@ # This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +.SUFFIXES: MAKEFLAGS+=-r -j8 COMMA=, @@ -107,6 +108,7 @@ coverage: $(TESTS_TARGET) rm default.profraw default-flags.profraw llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests llvm-cov report -ignore-filename-regex=\(tests\|extern\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests + llvm-cov export -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info format: find . -name '*.h' -or -name '*.cpp' | xargs clang-format -i diff --git a/Sources.cmake b/Sources.cmake index 83ed52301..c30cf77d9 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -46,6 +46,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/Predicate.h + Analysis/include/Luau/Quantify.h Analysis/include/Luau/RecursionCounter.h Analysis/include/Luau/RequireTracer.h Analysis/include/Luau/Scope.h @@ -63,6 +64,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/TypeVar.h Analysis/include/Luau/Unifiable.h Analysis/include/Luau/Unifier.h + Analysis/include/Luau/UnifierSharedState.h Analysis/include/Luau/Variant.h Analysis/include/Luau/VisitTypeVar.h @@ -77,6 +79,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Linter.cpp Analysis/src/Module.cpp Analysis/src/Predicate.cpp + Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp Analysis/src/Substitution.cpp diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 7a09ae9f8..30cffaff8 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -76,7 +76,7 @@ struct luaL_Buffer char buffer[LUA_BUFFERSIZE]; }; -// when internal buffer storage is exhaused, a mutable string value 'storage' will be placed on the stack +// when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) // with the exception of luaL_addvalue that expects the value at the top and string buffer further away (top-2) // functions that accept a 'boxloc' support string buffer placement at any location in the stack diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 013153605..f2e97c669 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1153,7 +1151,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) luaC_checkGC(L); luaC_checkthreadsleep(L); Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); - memcpy(u->data + sz, &dtor, sizeof(dtor)); + memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); return u->data; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index e37618f7b..2a684ee4e 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -313,7 +313,7 @@ static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desired { size_t newsize = currentsize + currentsize / 2; - // check for size oveflow + // check for size overflow if (SIZE_MAX - desiredsize < currentsize) luaL_error(L, "buffer too large"); diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index c0b50b969..9724c0e72 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,8 +5,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) - #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -17,7 +15,7 @@ LUAU_FASTFLAGVARIABLE(LuauPreferXpush, false) static const char* const statnames[] = {"running", "suspended", "normal", "dead"}; -static int costatus(lua_State* L, lua_State* co) +static int auxstatus(lua_State* L, lua_State* co) { if (co == L) return CO_RUN; @@ -25,7 +23,7 @@ static int costatus(lua_State* L, lua_State* co) return CO_SUS; if (co->status == LUA_BREAK) return CO_NOR; - if (co->status != 0) /* some error occured */ + if (co->status != 0) /* some error occurred */ return CO_DEAD; if (co->ci != co->base_ci) /* does it have frames? */ return CO_NOR; @@ -34,11 +32,11 @@ static int costatus(lua_State* L, lua_State* co) return CO_SUS; /* initial state */ } -static int luaB_costatus(lua_State* L) +static int costatus(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); - lua_pushstring(L, statnames[costatus(L, co)]); + lua_pushstring(L, statnames[auxstatus(L, co)]); return 1; } @@ -47,7 +45,7 @@ static int auxresume(lua_State* L, lua_State* co, int narg) // error handling for edge cases if (co->status != LUA_YIELD) { - int status = costatus(L, co); + int status = auxstatus(L, co); if (status != CO_SUS) { lua_pushfstring(L, "cannot resume %s coroutine", statnames[status]); @@ -115,7 +113,7 @@ static int auxresumecont(lua_State* L, lua_State* co) } } -static int luaB_coresumefinish(lua_State* L, int r) +static int coresumefinish(lua_State* L, int r) { if (r < 0) { @@ -131,7 +129,7 @@ static int luaB_coresumefinish(lua_State* L, int r) } } -static int luaB_coresumey(lua_State* L) +static int coresumey(lua_State* L) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -141,10 +139,10 @@ static int luaB_coresumey(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_coresumecont(lua_State* L, int status) +static int coresumecont(lua_State* L, int status) { lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); @@ -155,10 +153,10 @@ static int luaB_coresumecont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_coresumefinish(L, r); + return coresumefinish(L, r); } -static int luaB_auxwrapfinish(lua_State* L, int r) +static int auxwrapfinish(lua_State* L, int r) { if (r < 0) { @@ -173,7 +171,7 @@ static int luaB_auxwrapfinish(lua_State* L, int r) return r; } -static int luaB_auxwrapy(lua_State* L) +static int auxwrapy(lua_State* L) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); int narg = cast_int(L->top - L->base); @@ -182,10 +180,10 @@ static int luaB_auxwrapy(lua_State* L) if (r == CO_STATUS_BREAK) return interruptThread(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_auxwrapcont(lua_State* L, int status) +static int auxwrapcont(lua_State* L, int status) { lua_State* co = lua_tothread(L, lua_upvalueindex(1)); @@ -195,62 +193,52 @@ static int luaB_auxwrapcont(lua_State* L, int status) int r = auxresumecont(L, co); - return luaB_auxwrapfinish(L, r); + return auxwrapfinish(L, r); } -static int luaB_cocreate(lua_State* L) +static int cocreate(lua_State* L) { luaL_checktype(L, 1, LUA_TFUNCTION); lua_State* NL = lua_newthread(L); - - if (FFlag::LuauPreferXpush) - { - lua_xpush(L, NL, 1); // push function on top of NL - } - else - { - lua_pushvalue(L, 1); /* move function to top */ - lua_xmove(L, NL, 1); /* move function from L to NL */ - } - + lua_xpush(L, NL, 1); // push function on top of NL return 1; } -static int luaB_cowrap(lua_State* L) +static int cowrap(lua_State* L) { - luaB_cocreate(L); + cocreate(L); - lua_pushcfunction(L, luaB_auxwrapy, NULL, 1, luaB_auxwrapcont); + lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } -static int luaB_yield(lua_State* L) +static int coyield(lua_State* L) { int nres = cast_int(L->top - L->base); return lua_yield(L, nres); } -static int luaB_corunning(lua_State* L) +static int corunning(lua_State* L) { if (lua_pushthread(L)) lua_pushnil(L); /* main thread is not a coroutine */ return 1; } -static int luaB_yieldable(lua_State* L) +static int coyieldable(lua_State* L) { lua_pushboolean(L, lua_isyieldable(L)); return 1; } static const luaL_Reg co_funcs[] = { - {"create", luaB_cocreate}, - {"running", luaB_corunning}, - {"status", luaB_costatus}, - {"wrap", luaB_cowrap}, - {"yield", luaB_yield}, - {"isyieldable", luaB_yieldable}, + {"create", cocreate}, + {"running", corunning}, + {"status", costatus}, + {"wrap", cowrap}, + {"yield", coyield}, + {"isyieldable", coyieldable}, {NULL, NULL}, }; @@ -258,7 +246,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, luaB_coresumey, "resume", 0, luaB_coresumecont); + lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 39a615978..328b47e69 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -18,6 +18,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) +LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) /* ** {====================================================== @@ -536,7 +537,13 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } - // an error occured, check if we have a protected error callback + if (FFlag::LuauCcallRestoreFix) + { + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; + } + + // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) { L->global->cb.debugprotectederror(L); @@ -549,7 +556,10 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - L->nCcalls = oldnCcalls; + if (!FFlag::LuauCcallRestoreFix) + { + L->nCcalls = oldnCcalls; + } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 4fe1c3418..72807f0f3 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -37,7 +37,7 @@ /* results from luaD_precall */ #define PCRLUA 0 /* initiated a call to a Lua function */ #define PCRC 1 /* did a call to a C function */ -#define PCRYIELD 2 /* C funtion yielded */ +#define PCRYIELD 2 /* C function yielded */ /* type of protected functions, to be ran by `runprotected' */ typedef void (*Pfunc)(lua_State* L, void* ud); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 0b5430269..648785697 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -76,7 +76,7 @@ UpVal* luaF_findupval(lua_State* L, StkId level) if (p->v == level) { /* found a corresponding upvalue? */ if (isdead(g, obj2gco(p))) /* is it dead? */ - changewhite(obj2gco(p)); /* ressurect it */ + changewhite(obj2gco(p)); /* resurrect it */ return p; } pp = &p->next; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 510a9f548..6553009ff 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -12,11 +12,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgain, false) LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) -LUAU_FASTFLAGVARIABLE(LuauGcFullSkipInactiveThreads, false) -LUAU_FASTFLAGVARIABLE(LuauShrinkWeakTables, false) LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) +LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -66,13 +64,18 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, g->gcstats.currcycle.marktime += seconds; // atomic step had to be performed during the switch and it's tracked separately - if (g->gcstate == GCSsweepstring) + if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; + case GCSatomic: + g->gcstats.currcycle.atomictime += seconds; + break; case GCSsweepstring: case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; break; + default: + LUAU_ASSERT(!"Unexpected GC state"); } if (assist) @@ -183,33 +186,15 @@ static int traversetable(global_State* g, Table* h) if (h->metatable) markobject(g, cast_to(Table*, h->metatable)); - if (FFlag::LuauShrinkWeakTables) - { - /* is there a weak mode? */ - if (const char* modev = gettablemode(g, h)) - { - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } - } - } - else + /* is there a weak mode? */ + if (const char* modev = gettablemode(g, h)) { - const TValue* mode = gfasttm(g, h->metatable, TM_MODE); - if (mode && ttisstring(mode)) - { /* is there a weak mode? */ - const char* modev = svalue(mode); - weakkey = (strchr(modev, 'k') != NULL); - weakvalue = (strchr(modev, 'v') != NULL); - if (weakkey || weakvalue) - { /* is really weak? */ - h->gclist = g->weak; /* must be cleared after GC, ... */ - g->weak = obj2gco(h); /* ... so put in the appropriate list */ - } + weakkey = (strchr(modev, 'k') != NULL); + weakvalue = (strchr(modev, 'v') != NULL); + if (weakkey || weakvalue) + { /* is really weak? */ + h->gclist = g->weak; /* must be cleared after GC, ... */ + g->weak = obj2gco(h); /* ... so put in the appropriate list */ } } @@ -297,7 +282,7 @@ static void traversestack(global_State* g, lua_State* l, bool clearstack) for (StkId o = l->stack; o < l->top; o++) markvalue(g, o); /* final traversal? */ - if (g->gcstate == GCSatomic || (FFlag::LuauGcFullSkipInactiveThreads && clearstack)) + if (g->gcstate == GCSatomic || clearstack) { StkId stack_end = l->stack + l->stacksize; for (StkId o = l->top; o < stack_end; o++) /* clear not-marked stack slice */ @@ -336,28 +321,16 @@ static size_t propagatemark(global_State* g) lua_State* th = gco2th(o); g->gray = th->gclist; - if (FFlag::LuauGcFullSkipInactiveThreads) - { - LUAU_ASSERT(!luaC_threadsleeping(th)); - - // threads that are executing and the main thread are not deactivated - bool active = luaC_threadactive(th) || th == th->global->mainthread; + LUAU_ASSERT(!luaC_threadsleeping(th)); - if (!active && g->gcstate == GCSpropagate) - { - traversestack(g, th, /* clearstack= */ true); + // threads that are executing and the main thread are not deactivated + bool active = luaC_threadactive(th) || th == th->global->mainthread; - l_setbit(th->stackstate, THREAD_SLEEPINGBIT); - } - else - { - th->gclist = g->grayagain; - g->grayagain = o; - - black2gray(o); + if (!active && g->gcstate == GCSpropagate) + { + traversestack(g, th, /* clearstack= */ true); - traversestack(g, th, /* clearstack= */ false); - } + l_setbit(th->stackstate, THREAD_SLEEPINGBIT); } else { @@ -385,12 +358,14 @@ static size_t propagatemark(global_State* g) } } -static void propagateall(global_State* g) +static size_t propagateall(global_State* g) { + size_t work = 0; while (g->gray) { - propagatemark(g); + work += propagatemark(g); } + return work; } /* @@ -415,11 +390,14 @@ static int isobjcleared(GCObject* o) /* ** clear collected entries from weaktables */ -static void cleartable(lua_State* L, GCObject* l) +static size_t cleartable(lua_State* L, GCObject* l) { + size_t work = 0; while (l) { Table* h = gco2h(l); + work += sizeof(Table) + sizeof(TValue) * h->sizearray + sizeof(LuaNode) * sizenode(h); + int i = h->sizearray; while (i--) { @@ -433,50 +411,36 @@ static void cleartable(lua_State* L, GCObject* l) { LuaNode* n = gnode(h, i); - if (FFlag::LuauShrinkWeakTables) + // non-empty entry? + if (!ttisnil(gval(n))) { - // non-empty entry? - if (!ttisnil(gval(n))) - { - // can we clear key or value? - if (iscleared(gkey(n)) || iscleared(gval(n))) - { - setnilvalue(gval(n)); /* remove value ... */ - removeentry(n); /* remove entry from table */ - } - else - { - activevalues++; - } - } - } - else - { - if (!ttisnil(gval(n)) && /* non-empty entry? */ - (iscleared(gkey(n)) || iscleared(gval(n)))) + // can we clear key or value? + if (iscleared(gkey(n)) || iscleared(gval(n))) { setnilvalue(gval(n)); /* remove value ... */ removeentry(n); /* remove entry from table */ } + else + { + activevalues++; + } } } - if (FFlag::LuauShrinkWeakTables) + if (const char* modev = gettablemode(L->global, h)) { - if (const char* modev = gettablemode(L->global, h)) + // are we allowed to shrink this weak table? + if (strchr(modev, 's')) { - // are we allowed to shrink this weak table? - if (strchr(modev, 's')) - { - // shrink at 37.5% occupancy - if (activevalues < sizenode(h) * 3 / 8) - luaH_resizehash(L, h, activevalues); - } + // shrink at 37.5% occupancy + if (activevalues < sizenode(h) * 3 / 8) + luaH_resizehash(L, h, activevalues); } } l = h->gclist; } + return work; } static void shrinkstack(lua_State* L) @@ -655,37 +619,49 @@ static void markroot(lua_State* L) g->gcstate = GCSpropagate; } -static void remarkupvals(global_State* g) +static size_t remarkupvals(global_State* g) { - UpVal* uv; - for (uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + size_t work = 0; + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { + work += sizeof(UpVal); LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); if (isgray(obj2gco(uv))) markvalue(g, uv->v); } + return work; } -static void atomic(lua_State* L) +static size_t atomic(lua_State* L) { global_State* g = L->global; - g->gcstate = GCSatomic; + size_t work = 0; + + if (FFlag::LuauSeparateAtomic) + { + LUAU_ASSERT(g->gcstate == GCSatomic); + } + else + { + g->gcstate = GCSatomic; + } + /* remark occasional upvalues of (maybe) dead threads */ - remarkupvals(g); + work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ - propagateall(g); + work += propagateall(g); /* remark weak tables */ g->gray = g->weak; g->weak = NULL; LUAU_ASSERT(!iswhite(obj2gco(g->mainthread))); markobject(g, L); /* mark running thread */ markmt(g); /* mark basic metatables (again) */ - propagateall(g); + work += propagateall(g); /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; - propagateall(g); - cleartable(L, g->weak); /* remove collected objects from weak tables */ + work += propagateall(g); + work += cleartable(L, g->weak); /* remove collected objects from weak tables */ g->weak = NULL; /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); @@ -693,7 +669,12 @@ static void atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - GC_INTERRUPT(GCSatomic); + if (!FFlag::LuauSeparateAtomic) + { + GC_INTERRUPT(GCSatomic); + } + + return work; } static size_t singlestep(lua_State* L) @@ -705,46 +686,24 @@ static size_t singlestep(lua_State* L) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + if (g->gray) { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; + g->gcstats.currcycle.markitems++; - g->gcstate = GCSpropagateagain; - } + cost = propagatemark(g); } else { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -758,17 +717,34 @@ static size_t singlestep(lua_State* L) } else /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { size_t traversedcount = 0; @@ -806,7 +782,7 @@ static size_t singlestep(lua_State* L) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; @@ -821,48 +797,25 @@ static size_t gcstep(lua_State* L, size_t limit) case GCSpause: { markroot(L); /* start a new collection */ + LUAU_ASSERT(g->gcstate == GCSpropagate); break; } case GCSpropagate: { - if (FFlag::LuauRescanGrayAgain) + while (g->gray && cost < limit) { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; - - cost += propagatemark(g); - } - - if (!g->gray) - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; + g->gcstats.currcycle.markitems++; - g->gcstate = GCSpropagateagain; - } + cost += propagatemark(g); } - else - { - while (g->gray && cost < limit) - { - g->gcstats.currcycle.markitems++; - - cost += propagatemark(g); - } - - if (!g->gray) /* no more `gray' objects */ - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + if (!g->gray) + { + // perform one iteration over 'gray again' list + g->gray = g->grayagain; + g->grayagain = NULL; - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSpropagateagain; } break; } @@ -877,17 +830,34 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - double starttimestamp = lua_clock(); + if (FFlag::LuauSeparateAtomic) + { + g->gcstate = GCSatomic; + } + else + { + double starttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - atomic(L); /* finish mark phase */ + atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; + } } break; } + case GCSatomic: + { + g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); + g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; + + cost = atomic(L); /* finish mark phase */ + LUAU_ASSERT(g->gcstate == GCSsweepstring); + break; + } case GCSsweepstring: { while (g->sweepstrgc < g->strt.size && cost < limit) @@ -934,7 +904,7 @@ static size_t gcstep(lua_State* L, size_t limit) break; } default: - LUAU_ASSERT(0); + LUAU_ASSERT(!"Unexpected GC state"); } return cost; } @@ -1008,7 +978,7 @@ void luaC_step(lua_State* L, bool assist) startGcCycleStats(g); int lastgcstate = g->gcstate; - double lastttimestamp = lua_clock(); + double lasttimestamp = lua_clock(); if (FFlag::LuauConsolidatedStep) { @@ -1034,15 +1004,15 @@ void luaC_step(lua_State* L, bool assist) double now = lua_clock(); - recordGcStateTime(g, lastgcstate, now - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); - lastttimestamp = now; + lasttimestamp = now; lastgcstate = g->gcstate; } } while (lim > 0 && g->gcstate != GCSpause); } - recordGcStateTime(g, lastgcstate, lua_clock() - lastttimestamp, assist); + recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); // at the end of the last cycle if (g->gcstate == GCSpause) @@ -1084,7 +1054,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1095,7 +1065,7 @@ void luaC_fullgc(lua_State* L) g->weak = NULL; g->gcstate = GCSsweepstring; } - LUAU_ASSERT(g->gcstate != GCSpause && g->gcstate != GCSpropagate && g->gcstate != GCSpropagateagain); + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); /* finish any pending sweep phase */ while (g->gcstate != GCSpause) { @@ -1143,14 +1113,11 @@ void luaC_fullgc(lua_State* L) void luaC_barrierupval(lua_State* L, GCObject* v) { - if (FFlag::LuauGcFullSkipInactiveThreads) - { - global_State* g = L->global; - LUAU_ASSERT(iswhite(v) && !isdead(g, v)); + global_State* g = L->global; + LUAU_ASSERT(iswhite(v) && !isdead(g, v)); - if (keepinvariant(g)) - reallymarkobject(g, v); - } + if (keepinvariant(g)) + reallymarkobject(g, v); } void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v) @@ -1778,7 +1745,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= GCSpropagateagain) + if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index dc780bba5..f434e5064 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,8 +6,6 @@ #include "lobject.h" #include "lstate.h" -LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) - /* ** Possible states of the Garbage Collector */ @@ -25,10 +23,10 @@ LUAU_FASTFLAG(LuauGcFullSkipInactiveThreads) ** still-black objects. The invariant is restored when sweep ends and ** all objects are white again. */ -#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain) +#define keepinvariant(g) ((g)->gcstate == GCSpropagate || (g)->gcstate == GCSpropagateagain || (g)->gcstate == GCSatomic) /* -** some userful bit tricks +** some useful bit tricks */ #define resetbits(x, m) ((x) &= cast_to(uint8_t, ~(m))) #define setbits(x, m) ((x) |= (m)) @@ -147,4 +145,4 @@ LUAI_FUNC void luaC_validate(lua_State* L); LUAI_FUNC void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); LUAI_FUNC int64_t luaC_allocationrate(lua_State* L); LUAI_FUNC void luaC_wakethread(lua_State* L); -LUAI_FUNC const char* luaC_statename(int state); \ No newline at end of file +LUAI_FUNC const char* luaC_statename(int state); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 2759f3b80..d8b265cba 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -199,7 +199,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) if (page->freeNext >= 0) { - block = page->data + page->freeNext; + block = &page->data + page->freeNext; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); page->freeNext -= page->blockSize; diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index d77e17c9b..18ee1cda5 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -226,7 +226,7 @@ void luaS_freeudata(lua_State* L, Udata* u) void (*dtor)(void*) = nullptr; if (u->tag == UTAG_IDTOR) - memcpy(&dtor, u->data + u->len - sizeof(dtor), sizeof(dtor)); + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); else if (u->tag) dtor = L->global->udatagc[u->tag]; diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index b932a85b3..a168b6522 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens -template +template struct TempBuffer { lua_State* L; @@ -346,6 +346,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint32_t mainid = readVarInt(data, size, offset); Proto* main = protos[mainid]; + luaC_checkthreadsleep(L); + Closure* cl = luaF_newLclosure(L, 0, envt, main); setclvalue(L, L->top, cl); incr_top(L); diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua new file mode 100644 index 000000000..87b9abfd4 --- /dev/null +++ b/bench/tests/chess.lua @@ -0,0 +1,849 @@ + +local bench = script and require(script.Parent.bench_support) or require("bench_support") + +local RANKS = "12345678" +local FILES = "abcdefgh" +local PieceSymbols = "PpRrNnBbQqKk" +local UnicodePieces = {"♙", "♟", "♖", "♜", "♘", "♞", "♗", "♝", "♕", "♛", "♔", "♚"} +local StartingFen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1" + +-- +-- Lua 5.2 Compat +-- + +if not table.create then + function table.create(n, v) + local result = {} + for i=1,n do result[i] = v end + return result + end +end + +if not table.move then + function table.move(a, from, to, start, target) + local dx = start - from + for i=from,to do + target[i+dx] = a[i] + end + end +end + + +-- +-- Utils +-- + +local function square(s) + return RANKS:find(s:sub(2,2)) * 8 + FILES:find(s:sub(1,1)) - 9 +end + +local function squareName(n) + local file = n % 8 + local rank = (n-file)/8 + return FILES:sub(file+1,file+1) .. RANKS:sub(rank+1,rank+1) +end + +local function moveName(v ) + local from = bit32.extract(v, 6, 6) + local to = bit32.extract(v, 0, 6) + local piece = bit32.extract(v, 20, 4) + local captured = bit32.extract(v, 25, 4) + + local move = PieceSymbols:sub(piece,piece) .. ' ' .. squareName(from) .. (captured ~= 0 and 'x' or '-') .. squareName(to) + + if bit32.extract(v,14) == 1 then + if to > from then + return "O-O" + else + return "O-O-O" + end + end + + local promote = bit32.extract(v,15,4) + if promote ~= 0 then + move = move .. "=" .. PieceSymbols:sub(promote,promote) + end + return move +end + +local function ucimove(m) + local mm = squareName(bit32.extract(m, 6, 6)) .. squareName(bit32.extract(m, 0, 6)) + local promote = bit32.extract(m,15,4) + if promote > 0 then + mm = mm .. PieceSymbols:sub(promote,promote):lower() + end + return mm +end + +local _utils = {squareName, moveName} + +-- +-- Bitboards +-- + +local Bitboard = {} + + +function Bitboard:toString() + local out = {} + + local src = self.h + for x=7,0,-1 do + table.insert(out, RANKS:sub(x+1,x+1)) + table.insert(out, " ") + local bit = bit32.lshift(1,(x%4) * 8) + for x=0,7 do + if bit32.band(src, bit) ~= 0 then + table.insert(out, "x ") + else + table.insert(out, "- ") + end + bit = bit32.lshift(bit, 1) + end + if x == 4 then + src = self.l + end + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, '#: ' .. self:popcnt() .. "\tl:" .. self.l .. "\th:" .. self.h) + return table.concat(out) +end + + +function Bitboard.from(l ,h ) + return setmetatable({l=l, h=h}, Bitboard) +end + +Bitboard.zero = Bitboard.from(0,0) +Bitboard.full = Bitboard.from(0xFFFFFFFF, 0xFFFFFFFF) + +local Rank1 = Bitboard.from(0x000000FF, 0) +local Rank3 = Bitboard.from(0x00FF0000, 0) +local Rank6 = Bitboard.from(0, 0x0000FF00) +local Rank8 = Bitboard.from(0, 0xFF000000) +local FileA = Bitboard.from(0x01010101, 0x01010101) +local FileB = Bitboard.from(0x02020202, 0x02020202) +local FileC = Bitboard.from(0x04040404, 0x04040404) +local FileD = Bitboard.from(0x08080808, 0x08080808) +local FileE = Bitboard.from(0x10101010, 0x10101010) +local FileF = Bitboard.from(0x20202020, 0x20202020) +local FileG = Bitboard.from(0x40404040, 0x40404040) +local FileH = Bitboard.from(0x80808080, 0x80808080) + +local _Files = {FileA, FileB, FileC, FileD, FileE, FileF, FileG, FileH} + +-- These masks are filled out below for all files +local RightMasks = {FileH} +local LeftMasks = {FileA} + + + +local function popcnt32(i) + i = i - bit32.band(bit32.rshift(i,1), 0x55555555) + i = bit32.band(i, 0x33333333) + bit32.band(bit32.rshift(i,2), 0x33333333) + return bit32.rshift(bit32.band(i + bit32.rshift(i,4), 0x0F0F0F0F) * 0x01010101, 24) +end + +function Bitboard:up() + return self:lshift(8) +end + +function Bitboard:down() + return self:rshift(8) +end + +function Bitboard:right() + return self:band(FileH:inverse()):lshift(1) +end + +function Bitboard:left() + return self:band(FileA:inverse()):rshift(1) +end + +function Bitboard:move(x,y) + local out = self + + if x < 0 then out = out:bandnot(RightMasks[-x]):lshift(-x) end + if x > 0 then out = out:bandnot(LeftMasks[x]):rshift(x) end + + if y < 0 then out = out:rshift(-8 * y) end + if y > 0 then out = out:lshift(8 * y) end + return out +end + + +function Bitboard:popcnt() + return popcnt32(self.l) + popcnt32(self.h) +end + +function Bitboard:band(other ) + return Bitboard.from(bit32.band(self.l,other.l), bit32.band(self.h, other.h)) +end + +function Bitboard:bandnot(other ) + return Bitboard.from(bit32.band(self.l,bit32.bnot(other.l)), bit32.band(self.h, bit32.bnot(other.h))) +end + +function Bitboard:bandempty(other ) + return bit32.band(self.l,other.l) == 0 and bit32.band(self.h, other.h) == 0 +end + +function Bitboard:bor(other ) + return Bitboard.from(bit32.bor(self.l,other.l), bit32.bor(self.h, other.h)) +end + +function Bitboard:bxor(other ) + return Bitboard.from(bit32.bxor(self.l,other.l), bit32.bxor(self.h, other.h)) +end + +function Bitboard:inverse() + return Bitboard.from(bit32.bxor(self.l,0xFFFFFFFF), bit32.bxor(self.h, 0xFFFFFFFF)) +end + +function Bitboard:empty() + return self.h == 0 and self.l == 0 +end + +function Bitboard:ctz() + local target = self.l + local offset = 0 + local result = 0 + + if target == 0 then + target = self.h + result = 32 + end + + if target == 0 then + return 64 + end + + while bit32.extract(target, offset) == 0 do + offset = offset + 1 + end + + return result + offset +end + +function Bitboard:ctzafter(start) + start = start + 1 + if start < 32 then + for i=start,31 do + if bit32.extract(self.l, i) == 1 then return i end + end + end + for i=math.max(32,start),63 do + if bit32.extract(self.h, i-32) == 1 then return i end + end + return 64 +end + + +function Bitboard:lshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + + if amt > 31 then + return Bitboard.from(0, bit32.lshift(self.l, amt-31)) + end + + local l = bit32.lshift(self.l, amt) + local h = bit32.bor( + bit32.lshift(self.h, amt), + bit32.extract(self.l, 32-amt, amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:rshift(amt) + assert(amt >= 0) + if amt == 0 then return self end + local h = bit32.rshift(self.h, amt) + local l = bit32.bor( + bit32.rshift(self.l, amt), + bit32.lshift(bit32.extract(self.h, 0, amt), 32-amt) + ) + return Bitboard.from(l, h) +end + +function Bitboard:index(i) + if i > 31 then + return bit32.extract(self.h, i - 32) + else + return bit32.extract(self.l, i) + end +end + +function Bitboard:set(i , v) + if i > 31 then + return Bitboard.from(self.l, bit32.replace(self.h, v, i - 32)) + else + return Bitboard.from(bit32.replace(self.l, v, i), self.h) + end +end + +function Bitboard:isolate(i) + return self:band(Bitboard.some(i)) +end + +function Bitboard.some(idx ) + return Bitboard.zero:set(idx, 1) +end + +Bitboard.__index = Bitboard +Bitboard.__tostring = Bitboard.toString + +for i=2,8 do + RightMasks[i] = RightMasks[i-1]:rshift(1):bor(FileH) + LeftMasks[i] = LeftMasks[i-1]:lshift(1):bor(FileA) +end +-- +-- Board +-- + +local Board = {} + + +function Board.new() + local boards = table.create(12, Bitboard.zero) + boards.ocupied = Bitboard.zero + boards.white = Bitboard.zero + boards.black = Bitboard.zero + boards.unocupied = Bitboard.full + boards.ep = Bitboard.zero + boards.castle = Bitboard.zero + boards.toMove = 1 + boards.hm = 0 + boards.moves = 0 + boards.material = 0 + + return setmetatable(boards, Board) +end + +function Board.fromFen(fen ) + local b = Board.new() + local i = 0 + local rank = 7 + local file = 0 + + while true do + i = i + 1 + local p = fen:sub(i,i) + if p == '/' then + rank = rank - 1 + file = 0 + elseif tonumber(p) ~= nil then + file = file + tonumber(p) + else + local pidx = PieceSymbols:find(p) + if pidx == nil then break end + b[pidx] = b[pidx]:set(rank*8+file, 1) + file = file + 1 + end + end + + + local move, castle, ep, hm, m = string.match(fen, "^ ([bw]) ([KQkq-]*) ([a-h-][0-9]?) (%d*) (%d*)", i) + if move == nil then print(fen:sub(i)) end + b.toMove = move == 'w' and 1 or 2 + + if ep ~= "-" then + b.ep = Bitboard.some(square(ep)) + end + + if castle ~= "-" then + local oo = Bitboard.zero + if castle:find("K") then + oo = oo:set(7, 1) + end + if castle:find("Q") then + oo = oo:set(0, 1) + end + if castle:find("k") then + oo = oo:set(63, 1) + end + if castle:find("q") then + oo = oo:set(56, 1) + end + + b.castle = oo + end + + b.hm = hm + b.moves = m + + b:updateCache() + return b + +end + +function Board:index(idx ) + if self.white:index(idx) == 1 then + for p=1,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + else + for p=2,12,2 do + if self[p]:index(idx) == 1 then + return p + end + end + end + + return 0 +end + +function Board:updateCache() + for i=1,11,2 do + self.white = self.white:bor(self[i]) + self.black = self.black:bor(self[i+1]) + end + + self.ocupied = self.black:bor(self.white) + self.unocupied = self.ocupied:inverse() + self.material = + 100*self[1]:popcnt() - 100*self[2]:popcnt() + + 500*self[3]:popcnt() - 500*self[4]:popcnt() + + 300*self[5]:popcnt() - 300*self[6]:popcnt() + + 300*self[7]:popcnt() - 300*self[8]:popcnt() + + 900*self[9]:popcnt() - 900*self[10]:popcnt() + +end + +function Board:fen() + local out = {} + local s = 0 + local idx = 56 + for i=0,63 do + if i % 8 == 0 and i > 0 then + idx = idx - 16 + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, '/') + end + local p = self:index(idx) + if p == 0 then + s = s + 1 + else + if s > 0 then + table.insert(out, '' .. s) + s = 0 + end + table.insert(out, PieceSymbols:sub(p,p)) + end + + idx = idx + 1 + end + if s > 0 then + table.insert(out, '' .. s) + end + + table.insert(out, self.toMove == 1 and ' w ' or ' b ') + if self.castle:empty() then + table.insert(out, '-') + else + if self.castle:index(7) == 1 then table.insert(out, 'K') end + if self.castle:index(0) == 1 then table.insert(out, 'Q') end + if self.castle:index(63) == 1 then table.insert(out, 'k') end + if self.castle:index(56) == 1 then table.insert(out, 'q') end + end + + table.insert(out, ' ') + if self.ep:empty() then + table.insert(out, '-') + else + table.insert(out, squareName(self.ep:ctz())) + end + + table.insert(out, ' ' .. self.hm) + table.insert(out, ' ' .. self.moves) + + return table.concat(out) +end + +function Board:pmoves(idx) + return self:generate(idx) +end + +function Board:pcaptures(idx) + return self:generate(idx):band(self.ocupied) +end + +local ROOK_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}} +local BISHOP_SLIDES = {{1,1}, {-1,1}, {1,-1}, {-1,-1}} +local QUEEN_SLIDES = {{1,0}, {-1,0}, {0,1}, {0,-1}, {1,1}, {-1,1}, {1,-1}, {-1,-1}} +local KNIGHT_MOVES = {{2,1}, {2,-1}, {-2,1}, {-2,-1}, {1,2}, {1,-2}, {-1,2}, {-1,-2}} + +function Board:generate(idx) + local piece = self:index(idx) + local r = Bitboard.some(idx) + local out = Bitboard.zero + local type = bit32.rshift(piece - 1, 1) + local cancapture = piece % 2 == 1 and self.black or self.white + + if piece == 0 then return Bitboard.zero end + + if type == 0 then + -- Pawn + local d = -(piece*2 - 3) + local movetwo = piece == 1 and Rank3 or Rank6 + + out = out:bor(r:move(0,d):band(self.unocupied)) + out = out:bor(out:band(movetwo):move(0,d):band(self.unocupied)) + + local captures = r:move(0,d) + captures = captures:right():bor(captures:left()) + + if not captures:bandempty(self.ep) then + out = out:bor(self.ep) + end + + captures = captures:band(cancapture) + out = out:bor(captures) + + return out + elseif type == 5 then + -- King + for x=-1,1,1 do + for y = -1,1,1 do + local w = r:move(x,y) + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + end + elseif type == 2 then + -- Knight + for _,j in ipairs(KNIGHT_MOVES) do + local w = r:move(j[1],j[2]) + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + end + end + else + -- Sliders (Rook, Bishop, Queen) + local slides + if type == 1 then + slides = ROOK_SLIDES + elseif type == 3 then + slides = BISHOP_SLIDES + else + slides = QUEEN_SLIDES + end + + for _, op in ipairs(slides) do + local w = r + for i=1,7 do + w = w:move(op[1], op[2]) + if w:empty() then break end + + if self.ocupied:bandempty(w) then + out = out:bor(w) + else + if not cancapture:bandempty(w) then + out = out:bor(w) + end + break + end + end + end + end + + + return out +end + +-- 0-5 - From Square +-- 6-11 - To Square +-- 12 - is Check +-- 13 - Is EnPassent +-- 14 - Is Castle +-- 15-19 - Promotion Piece +-- 20-24 - Moved Pice +-- 25-29 - Captured Piece + + +function Board:toString(mark ) + local out = {} + for x=8,1,-1 do + table.insert(out, RANKS:sub(x,x) .. " ") + + for y=1,8 do + local n = 8*x+y-9 + local i = self:index(n) + if i == 0 then + table.insert(out, '-') + else + -- out = out .. PieceSymbols:sub(i,i) + table.insert(out, UnicodePieces[i]) + end + if mark ~= nil and mark:index(n) ~= 0 then + table.insert(out, ')') + elseif mark ~= nil and n < 63 and y < 8 and mark:index(n+1) ~= 0 then + table.insert(out, '(') + else + table.insert(out, ' ') + end + end + + table.insert(out, "\n") + end + table.insert(out, ' ' .. FILES:gsub('.', '%1 ') .. '\n') + table.insert(out, (self.toMove == 1 and "White" or "Black") .. ' e:' .. (self.material/100) .. "\n") + return table.concat(out) +end + +function Board:moveList() + local tm = self.toMove == 1 and self.white or self.black + local castle_rank = self.toMove == 1 and Rank1 or Rank8 + local out = {} + local function emit(id) + if not self:applyMove(id):illegalyChecked() then + table.insert(out, id) + end + end + + local cr = tm:band(self.castle):band(castle_rank) + if not cr:empty() then + local p = self.toMove == 1 and 11 or 12 + local tcolor = self.toMove == 1 and self.black or self.white + local kidx = self[p]:ctz() + + + local castle = bit32.replace(0, p, 20, 4) + castle = bit32.replace(castle, kidx, 6, 6) + castle = bit32.replace(castle, 1, 14) + + + local mustbeemptyl = LeftMasks[4]:bxor(FileA):band(castle_rank) + local cantbethreatened = FileD:bor(FileC):band(castle_rank):bor(self[p]) + if + not cr:bandempty(FileA) and + mustbeemptyl:bandempty(self.ocupied) and + not self:isSquareThreatened(cantbethreatened, tcolor) + then + emit(bit32.replace(castle, kidx - 2, 0, 6)) + end + + + local mustbeemptyr = RightMasks[3]:bxor(FileH):band(castle_rank) + if + not cr:bandempty(FileH) and + mustbeemptyr:bandempty(self.ocupied) and + not self:isSquareThreatened(mustbeemptyr:bor(self[p]), tcolor) + then + emit(bit32.replace(castle, kidx + 2, 0, 6)) + end + end + + local sq = tm:ctz() + repeat + local p = self:index(sq) + local moves = self:pmoves(sq) + + while not moves:empty() do + local m = moves:ctz() + moves = moves:set(m, 0) + local id = bit32.replace(m, sq, 6, 6) + id = bit32.replace(id, p, 20, 4) + local mbb = Bitboard.some(m) + if not self.ocupied:bandempty(mbb) then + id = bit32.replace(id, self:index(m), 25, 4) + end + + -- Check if pawn needs to be promoted + if p == 1 and m >= 8*7 then + for i=3,9,2 do + emit(bit32.replace(id, i, 15, 4)) + end + elseif p == 2 and m < 8 then + for i=4,10,2 do + emit(bit32.replace(id, i, 15, 4)) + end + else + emit(id) + end + end + sq = tm:ctzafter(sq) + until sq == 64 + return out +end + +function Board:illegalyChecked() + local target = self.toMove == 1 and self[PieceSymbols:find("k")] or self[PieceSymbols:find("K")] + return self:isSquareThreatened(target, self.toMove == 1 and self.white or self.black) +end + +function Board:isSquareThreatened(target , color ) + local tm = color + local sq = tm:ctz() + repeat + local moves = self:pmoves(sq) + if not moves:bandempty(target) then + return true + end + sq = color:ctzafter(sq) + until sq == 64 + return false +end + +function Board:perft(depth ) + if depth == 0 then return 1 end + if depth == 1 then + return #self:moveList() + end + local result = 0 + for k,m in ipairs(self:moveList()) do + local c = self:applyMove(m):perft(depth - 1) + if c == 0 then + -- Perft only counts leaf nodes at target depth + -- result = result + 1 + else + result = result + c + end + end + return result +end + + +function Board:applyMove(move ) + local out = Board.new() + table.move(self, 1, 12, 1, out) + local from = bit32.extract(move, 6, 6) + local to = bit32.extract(move, 0, 6) + local promote = bit32.extract(move, 15, 4) + local piece = self:index(from) + local captured = self:index(to) + local tom = Bitboard.some(to) + local isCastle = bit32.extract(move, 14) + + if piece % 2 == 0 then + out.moves = self.moves + 1 + end + + if captured == 1 or piece < 3 then + out.hm = 0 + else + out.hm = self.hm + 1 + end + out.castle = self.castle + out.toMove = self.toMove == 1 and 2 or 1 + + if isCastle == 1 then + local rank = piece == 11 and Rank1 or Rank8 + local colorOffset = piece - 11 + + out[3 + colorOffset] = out[3 + colorOffset]:bandnot(from < to and FileH or FileA) + out[3 + colorOffset] = out[3 + colorOffset]:bor((from < to and FileF or FileD):band(rank)) + + out[piece] = (from < to and FileG or FileC):band(rank) + out.castle = out.castle:bandnot(rank) + out:updateCache() + return out + end + + if piece < 3 then + local dist = math.abs(to - from) + -- Pawn moved two squares, set ep square + if dist == 16 then + out.ep = Bitboard.some((from + to) / 2) + end + + -- Remove enpasent capture + if not tom:bandempty(self.ep) then + if piece == 1 then + out[2] = out[2]:bandnot(self.ep:down()) + end + if piece == 2 then + out[1] = out[1]:bandnot(self.ep:up()) + end + end + end + + if piece == 3 or piece == 4 then + out.castle = out.castle:set(from, 0) + end + + if piece > 10 then + local rank = piece == 11 and Rank1 or Rank8 + out.castle = out.castle:bandnot(rank) + end + + out[piece] = out[piece]:set(from, 0) + if promote == 0 then + out[piece] = out[piece]:set(to, 1) + else + out[promote] = out[promote]:set(to, 1) + end + if captured ~= 0 then + out[captured] = out[captured]:set(to, 0) + end + + out:updateCache() + return out +end + +Board.__index = Board +Board.__tostring = Board.toString +-- +-- Main +-- + +local failures = 0 +local function test(fen, ply, target) + local b = Board.fromFen(fen) + if b:fen() ~= fen then + print("FEN MISMATCH", fen, b:fen()) + failures = failures + 1 + return + end + + local found = b:perft(ply) + if found ~= target then + print(fen, "Found", found, "target", target) + failures = failures + 1 + for k,v in pairs(b:moveList()) do + print(ucimove(v) .. ': ' .. (ply > 1 and b:applyMove(v):perft(ply-1) or '1')) + end + --error("Test Failure") + else + print("OK", found, fen) + end +end + +-- From https://www.chessprogramming.org/Perft_Results +-- If interpreter, computers, or algorithm gets too fast +-- feel free to go deeper + +local testCases = {} +local function addTest(...) table.insert(testCases, {...}) end + +addTest(StartingFen, 3, 8902) +addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039) +addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812) +addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467) +addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486) +addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079) + + +local function chess() + for k,v in ipairs(testCases) do + test(v[1],v[2],v[3]) + end +end + +bench.runCode(chess, "chess") diff --git a/bench/tests/shootout/scimark.lua b/bench/tests/shootout/scimark.lua index 41d97bb8b..ad0557b1d 100644 --- a/bench/tests/shootout/scimark.lua +++ b/bench/tests/shootout/scimark.lua @@ -30,7 +30,7 @@ ------------------------------------------------------------------------------ ------------------------------------------------------------------------------ --- Modificatin to be compatible with Lua 5.3 +-- Modification to be compatible with Lua 5.3 ------------------------------------------------------------------------------ local bench = script and require(script.Parent.bench_support) or require("bench_support") diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 07910a0ac..44b8362df 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1596,7 +1596,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_argument_type_suggestion") local function target(a: number, b: string) return a + #b end local function d(a: n@1, b) - return target(a, b) + return target(a, b) end )"); @@ -1609,7 +1609,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: s@1) - return target(a, b) + return target(a, b) end )"); @@ -1622,7 +1622,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a:@1 @2, b) - return target(a, b) + return target(a, b) end )"); @@ -1640,7 +1640,7 @@ end local function target(a: number, b: string) return a + #b end local function d(a, b: @1)@2: number - return target(a, b) + return target(a, b) end )"); @@ -1682,7 +1682,7 @@ local x = target(function(a: n@1 local function target(callback: (a: number, b: string) -> number) return callback(4, "hello") end local x = target(function(a: n@1, b: @2) - return a + #b + return a + #b end) )"); @@ -1700,7 +1700,7 @@ end) local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a: n@1) - return a + return a end )"); @@ -1716,7 +1716,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_argument_type_pack_suggestio local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(...:n@1) - return a + return a end )"); @@ -1729,7 +1729,7 @@ end local function target(callback: (...number) -> number) return callback(1, 2, 3) end local x = target(function(a:number, b:number, ...:@1) - return a + b + return a + b end )"); @@ -1745,7 +1745,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_suggestion") local function target(callback: () -> number) return callback() end local x = target(function(): n@1 - return 1 + return 1 end )"); @@ -1758,7 +1758,7 @@ end local function target(callback: () -> (number, number)) return callback() end local x = target(function(): (number, n@1 - return 1, 2 + return 1, 2 end )"); @@ -1774,7 +1774,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_expected_return_type_pack_suggestion" local function target(callback: () -> ...number) return callback() end local x = target(function(): ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); @@ -1787,7 +1787,7 @@ end local function target(callback: () -> ...number) return callback() end local x = target(function(): (number, number, ...n@1 - return 1, 2, 3 + return 1, 2, 3 end )"); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 54a31a68a..bbac3302c 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -768,11 +768,11 @@ TEST_CASE("CaptureSelf") local MaterialsListClass = {} function MaterialsListClass:_MakeToolTip(guiElement, text) - local function updateTooltipPosition() - self._tweakingTooltipFrame = 5 - end + local function updateTooltipPosition() + self._tweakingTooltipFrame = 5 + end - updateTooltipPosition() + updateTooltipPosition() end return MaterialsListClass @@ -2001,14 +2001,14 @@ TEST_CASE("UpvaluesLoopsBytecode") { CHECK_EQ("\n" + compileFunction(R"( function test() - for i=1,10 do + for i=1,10 do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2035,14 +2035,14 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - for i in ipairs(data) do + for i in ipairs(data) do i = i - foo(function() return i end) - if bar then - break - end - end - return 0 + foo(function() return i end) + if bar then + break + end + end + return 0 end )", 1), @@ -2068,17 +2068,17 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - while i < 5 do - local j + local i = 0 + while i < 5 do + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - end - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + end + return 0 end )", 1), @@ -2105,17 +2105,17 @@ RETURN R1 1 CHECK_EQ("\n" + compileFunction(R"( function test() - local i = 0 - repeat - local j + local i = 0 + repeat + local j j = i - foo(function() return j end) - i = i + 1 - if bar then - break - end - until i < 5 - return 0 + foo(function() return j end) + i = i + 1 + if bar then + break + end + until i < 5 + return 0 end )", 1), @@ -2304,10 +2304,10 @@ local Value1, Value2, Value3 = ... local Table = {} Table.SubTable["Key"] = { - Key1 = Value1, - Key2 = Value2, - Key3 = Value3, - Key4 = true, + Key1 = Value1, + Key2 = Value2, + Key3 = Value3, + Key4 = true, } )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5a697a49e..06b3c5237 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -801,4 +801,17 @@ TEST_CASE("IfElseExpression") runConformance("ifelseexpr.lua"); } +TEST_CASE("TagMethodError") +{ + ScopedFastFlag sff{"LuauCcallRestoreFix", true}; + + runConformance("tmerror.lua", [](lua_State* L) { + auto* cb = lua_callbacks(L); + + cb->debugprotectederror = [](lua_State* L) { + CHECK(lua_isyieldable(L)); + }; + }); +} + TEST_SUITE_END(); diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index 9f8748993..e55b5b0c3 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -2,6 +2,9 @@ #pragma once #include +#include + +namespace std { inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { @@ -9,10 +12,12 @@ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) } template -std::ostream& operator<<(std::ostream& lhs, const std::optional& t) +auto operator<<(std::ostream& lhs, const std::optional& t) -> decltype(lhs << *t) // SFINAE to only instantiate << for supported types { if (t) return lhs << *t; else return lhs << "none"; } + +} // namespace std diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index a9ed139f1..37f1b60b8 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -791,13 +791,13 @@ TEST_CASE_FIXTURE(Fixture, "TypeAnnotationsShouldNotProduceWarnings") { LintResult result = lint(R"(--!strict type InputData = { - id: number, - inputType: EnumItem, - inputState: EnumItem, - updated: number, - position: Vector3, - keyCode: EnumItem, - name: string + id: number, + inputType: EnumItem, + inputState: EnumItem, + updated: number, + position: Vector3, + keyCode: EnumItem, + name: string } )"); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 045f0230d..f580604ca 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -554,4 +554,54 @@ TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") CHECK_EQ("{number | string}", toString(requireType("p"), {true})); } +/* + * We had a problem where all type aliases would be prototyped into a child scope that happened + * to have the same level. This caused a problem where, if a sibling function referred to that + * type alias in its type signature, it would erroneously be quantified away, even though it doesn't + * actually belong to the function. + * + * We solved this by ascribing a unique subLevel to each prototyped alias. + */ +TEST_CASE_FIXTURE(Fixture, "do_not_quantify_unresolved_aliases") +{ + CheckResult result = check(R"( + --!strict + + local KeyPool = {} + + local function newkey(pool: KeyPool, index) + return {} + end + + function newKeyPool() + local pool = { + available = {} :: {Key}, + } + + return setmetatable(pool, KeyPool) + end + + export type KeyPool = typeof(newKeyPool()) + export type Key = typeof(newkey(newKeyPool(), 1)) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * We keep a cache of type alias onto TypeVar to prevent infinite types from + * being constructed via recursive or corecursive aliases. We have to adjust + * the TypeLevels of those generic TypeVars so that the unifier doesn't think + * they have improperly leaked out of their scope. + */ +TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_scope_if_they_are_reused_in_multiple_aliases") +{ + CheckResult result = check(R"( + type Array = {T} + type Exclude = T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index e89747762..17e32e9f9 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -359,7 +359,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_2_args_o CHECK_EQ(typeChecker.stringType, requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "table_insert_corrrectly_infers_type_of_array_3_args_overload") +TEST_CASE_FIXTURE(Fixture, "table_insert_correctly_infers_type_of_array_3_args_overload") { CheckResult result = check(R"( local t = {} diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 6da33a08f..eabf7e65d 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -437,8 +437,6 @@ TEST_CASE_FIXTURE(ClassFixture, "class_unification_type_mismatch_is_correct_orde TEST_CASE_FIXTURE(ClassFixture, "optional_class_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local b: Vector2? = nil local a = b.X + b.Z diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 3a04a18f5..581375a12 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -695,4 +695,25 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") CHECK(requireType("y1") == requireType("y2")); } +TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") +{ + ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true}; + ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; + + CheckResult result = check(R"( +local exports = {} +local nested = {} + +nested.name = function(t, k) + local a = t.x.y + return rawget(t, k) +end + +exports.nested = nested +return exports + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 8bcb02424..419da8ad1 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -9,12 +9,13 @@ #include LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; TEST_SUITE_BEGIN("ProvisionalTests"); -// These tests check for behavior that differes from the final behavior we'd +// These tests check for behavior that differs from the final behavior we'd // like to have. They serve to document the current state of the typechecker. // When making future improvements, its very likely these tests will break and // will need to be replaced. @@ -42,7 +43,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string expected = R"( + const std::string old_expected = R"( function f(a:{fn:()->(free,free...)}): () if type(a) == 'boolean'then local a1:boolean=a @@ -51,7 +52,21 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end end )"; - CHECK_EQ(expected, decorateWithTypes(code)); + + const std::string expected = R"( + function f(a:{fn:()->(a,b...)}): () + if type(a) == 'boolean'then + local a1:boolean=a + elseif a.fn()then + local a2:{fn:()->(a,b...)}=a + end + end + )"; + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ(expected, decorateWithTypes(code)); + else + CHECK_EQ(old_expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") @@ -263,8 +278,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { - ScopedFastInt sffi{"LuauTarjanChildLimit", 50}; - ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 50}; + ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; + ScopedFastInt sffi2{"LuauTypeInferIterationLimit", 1}; CheckResult result = check(R"LUA( local Result diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 31739cdc7..36dcaa959 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,6 +8,7 @@ LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauOrPredicate) +LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -698,10 +699,16 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); + else + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + else + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index b7f0dc7b0..f1451a815 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -617,7 +617,7 @@ TEST_CASE_FIXTURE(Fixture, "indexers_get_quantified_too") REQUIRE_EQ(indexer.indexType, typeChecker.numberType); - REQUIRE(nullptr != get(indexer.indexResultType)); + REQUIRE(nullptr != get(follow(indexer.indexResultType))); } TEST_CASE_FIXTURE(Fixture, "indexers_quantification_2") diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index b75878b7e..453817574 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -180,7 +180,12 @@ TEST_CASE_FIXTURE(Fixture, "expr_statement") TEST_CASE_FIXTURE(Fixture, "generic_function") { - CheckResult result = check("function id(x) return x end local a = id(55) local b = id(nil)"); + CheckResult result = check(R"( + function id(x) return x end + local a = id(55) + local b = id(nil) + )"); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ(*typeChecker.numberType, *requireType("a")); @@ -406,7 +411,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string - for p in primes3() do print(p) end -- no errror + for p in primes3() do print(p) end -- no error )"); LUAU_REQUIRE_ERROR_COUNT(2, result); @@ -1889,7 +1894,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") REQUIRE_EQ(2, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); std::vector fArgs = flatten(fType->argTypes).first; @@ -1926,7 +1931,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") REQUIRE_EQ(6, argVec.size()); - const FunctionTypeVar* fType = get(argVec[0]); + const FunctionTypeVar* fType = get(follow(argVec[0])); REQUIRE(fType != nullptr); } @@ -2549,7 +2554,7 @@ TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") --!strict local x = nil function f() g() end - -- make sure print(x) doen't get toposorted here, breaking the mutual block + -- make sure print(x) doesn't get toposorted here, breaking the mutual block function g() x = f end print(x) )"); @@ -2987,7 +2992,7 @@ TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indeces") +TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") { CheckResult result = check(R"( local key @@ -3176,7 +3181,24 @@ TEST_CASE_FIXTURE(Fixture, "too_many_return_values") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); - CHECK(acm->context == CountMismatch::Result); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "ignored_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55, "" + end + + local a = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); } TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") @@ -3194,6 +3216,8 @@ TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") CountMismatch* acm = get(result.errors[0]); REQUIRE(acm); CHECK_EQ(acm->context, CountMismatch::Return); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 1); } TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") @@ -3823,10 +3847,10 @@ local T: any T = {} T.__index = T function T.new(...) - local self = {} - setmetatable(self, T) - self:construct(...) - return self + local self = {} + setmetatable(self, T) + self:construct(...) + return self end function T:construct(index) end @@ -4049,11 +4073,11 @@ function n:Clone() end local m = {} function m.a(x) - x:Clone() + x:Clone() end function m.b() - m.a(n) + m.a(n) end return m @@ -4374,8 +4398,6 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { - ScopedFastFlag luauInferFunctionArgsFix("LuauInferFunctionArgsFix", true); - // Simple direct arg to arg propagation CheckResult result = check(R"( type Table = { x: number, y: number } @@ -4385,7 +4407,7 @@ f(function(a) return a.x + a.y end) LUAU_REQUIRE_NO_ERRORS(result); - // An optional funciton is accepted, but since we already provide a function, nil can be ignored + // An optional function is accepted, but since we already provide a function, nil can be ignored result = check(R"( type Table = { x: number, y: number } local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end @@ -4413,7 +4435,7 @@ f(function(a: number, b, c) return c and a + b or b - a end) LUAU_REQUIRE_NO_ERRORS(result); - // Anonymous function has a varyadic pack + // Anonymous function has a variadic pack result = check(R"( type Table = { x: number, y: number } local function f(a: (Table) -> number) return a({x = 1, y = 2}) end @@ -4432,7 +4454,7 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); - // Infer from varyadic packs into elements + // Infer from variadic packs into elements result = check(R"( function f(a: (...number) -> number) return a(1, 2) end f(function(a, b) return a + b end) @@ -4440,7 +4462,7 @@ f(function(a, b) return a + b end) LUAU_REQUIRE_NO_ERRORS(result); - // Infer from varyadic packs into varyadic packs + // Infer from variadic packs into variadic packs result = check(R"( type Table = { x: number, y: number } function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end @@ -4662,7 +4684,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( @@ -4679,7 +4700,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") { ScopedFastFlag sffs[] = { {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - {"LuauExtraNilRecovery", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1f4b63ef2..1192a8ac0 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -8,6 +8,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauQuantifyInPlace2); + using namespace Luau; struct TryUnifyFixture : Fixture @@ -15,7 +17,8 @@ struct TryUnifyFixture : Fixture TypeArena arena; ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; - Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, &iceHandler}; + UnifierSharedState unifierState{&iceHandler}; + Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -139,7 +142,10 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("(number) -> boolean", toString(requireType("f"))); + else + CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 3e1dedd47..8dab2605b 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -98,10 +98,10 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function") std::vector applyArgs = flatten(applyType->argTypes).first; REQUIRE_EQ(3, applyArgs.size()); - const FunctionTypeVar* fType = get(applyArgs[0]); + const FunctionTypeVar* fType = get(follow(applyArgs[0])); REQUIRE(fType != nullptr); - const FunctionTypeVar* gType = get(applyArgs[1]); + const FunctionTypeVar* gType = get(follow(applyArgs[1])); REQUIRE(gType != nullptr); std::vector gArgs = flatten(gType->argTypes).first; @@ -285,7 +285,7 @@ TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") { CheckResult result = check(R"( local _ = function():((...any)->(...any),()->()) - return function() end, function() end + return function() end, function() end end for y in _() do end diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 037144e2a..34c25a9fe 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -52,7 +52,7 @@ TEST_CASE_FIXTURE(Fixture, "allow_more_specific_assign") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign") { CheckResult result = check(R"( local a:number = 10 @@ -63,7 +63,7 @@ TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign") LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "disallow_less_specifc_assign2") +TEST_CASE_FIXTURE(Fixture, "disallow_less_specific_assign2") { CheckResult result = check(R"( local a:number? = 10 @@ -181,8 +181,6 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_optional_property") TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") { - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = {x: number} type B = {} @@ -242,8 +240,6 @@ TEST_CASE_FIXTURE(Fixture, "union_equality_comparisons") TEST_CASE_FIXTURE(Fixture, "optional_union_members") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = { a = { x = 1, y = 2 }, b = 3 } type A = typeof(a) @@ -259,8 +255,6 @@ local c = bf.a.y TEST_CASE_FIXTURE(Fixture, "optional_union_functions") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a.foo(x:number, y:number) return x + y end @@ -276,8 +270,6 @@ local c = b.foo(1, 2) TEST_CASE_FIXTURE(Fixture, "optional_union_methods") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( local a = {} function a:foo(x:number, y:number) return x + y end @@ -310,8 +302,6 @@ return f() TEST_CASE_FIXTURE(Fixture, "optional_field_access_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local b: A? = { x = 2 } @@ -327,8 +317,6 @@ local d = b.y TEST_CASE_FIXTURE(Fixture, "optional_index_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -341,8 +329,6 @@ local b = a[1] TEST_CASE_FIXTURE(Fixture, "optional_call_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = (number) -> number local a: A? = function(a) return -a end @@ -355,8 +341,6 @@ local b = a(4) TEST_CASE_FIXTURE(Fixture, "optional_assignment_errors") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = { x: number } local a: A? = { x = 2 } @@ -378,8 +362,6 @@ a.x = 2 TEST_CASE_FIXTURE(Fixture, "optional_length_error") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - CheckResult result = check(R"( type A = {number} local a: A? = {1, 2, 3} @@ -392,9 +374,6 @@ local b = #a TEST_CASE_FIXTURE(Fixture, "optional_missing_key_error_details") { - ScopedFastFlag luauExtraNilRecovery("LuauExtraNilRecovery", true); - ScopedFastFlag luauMissingUnionPropertyError("LuauMissingUnionPropertyError", true); - CheckResult result = check(R"( type A = { x: number, y: number } type B = { x: number, y: number } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index a679e3fd2..930c1a39b 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -265,4 +265,64 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); } +TEST_CASE("tagging_tables") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar ttv{TableTypeVar{}}; + CHECK(!Luau::hasTag(&ttv, "foo")); + Luau::attachTag(&ttv, "foo"); + CHECK(Luau::hasTag(&ttv, "foo")); +} + +TEST_CASE("tagging_classes") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + CHECK(!Luau::hasTag(&base, "foo")); + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); +} + +TEST_CASE("tagging_subclasses") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; + TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; + + CHECK(!Luau::hasTag(&base, "foo")); + CHECK(!Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&base, "foo"); + CHECK(Luau::hasTag(&base, "foo")); + CHECK(Luau::hasTag(&derived, "foo")); + + Luau::attachTag(&derived, "bar"); + CHECK(!Luau::hasTag(&base, "bar")); + CHECK(Luau::hasTag(&derived, "bar")); +} + +TEST_CASE("tagging_functions") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + TypePackVar empty{TypePack{}}; + TypeVar ftv{FunctionTypeVar{&empty, &empty}}; + CHECK(!Luau::hasTag(&ftv, "foo")); + Luau::attachTag(&ftv, "foo"); + CHECK(Luau::hasTag(&ftv, "foo")); +} + +TEST_CASE("tagging_props") +{ + ScopedFastFlag sff{"LuauRefactorTagging", true}; + + Property prop{}; + CHECK(!Luau::hasTag(prop, "foo")); + Luau::attachTag(prop, "foo"); + CHECK(Luau::hasTag(prop, "foo")); +} + TEST_SUITE_END(); diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index 79f8d9c2c..aac42c56f 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -319,7 +319,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 73c3833da..753296422 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -185,7 +185,7 @@ end assert(a == 5^4) --- access to locals of collected corroutines +-- access to locals of collected coroutines local C = {}; setmetatable(C, {__mode = "kv"}) local x = coroutine.wrap (function () local a = 10 diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index fd4b4de1b..4263dfda7 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -277,7 +277,7 @@ do assert(getmetatable(o) == tt) -- create new objects during GC local a = 'xuxu'..(10+3)..'joao', {} - ___Glob = o -- ressurect object! + ___Glob = o -- resurrect object! newproxy(o) -- creates a new one with same metatable print(">>> closing state " .. "<<<\n") end diff --git a/tests/conformance/locals.lua b/tests/conformance/locals.lua index cbe5f92d0..2d8d004b8 100644 --- a/tests/conformance/locals.lua +++ b/tests/conformance/locals.lua @@ -117,7 +117,7 @@ if rawget(_G, "querytab") then local t = querytab(a) for k,_ in pairs(a) do a[k] = nil end - collectgarbage() -- restore GC and collect dead fiels in `a' + collectgarbage() -- restore GC and collect dead fields in `a' for i=0,t-1 do local k = querytab(a, i) assert(k == nil or type(k) == 'number' or k == 'alo') diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index 5e8b9398c..d5bca44f0 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -172,7 +172,7 @@ end a = nil --- testing implicit convertions +-- testing implicit conversions local a,b = '10', '20' assert(a*b == 200 and a+b == 30 and a-b == -10 and a/b == 0.5 and -b == -20) diff --git a/tests/conformance/pm.lua b/tests/conformance/pm.lua index 9a113964a..263759ac0 100644 --- a/tests/conformance/pm.lua +++ b/tests/conformance/pm.lua @@ -21,9 +21,9 @@ a,b = string.find('alo', '') assert(a == 1 and b == 0) a,b = string.find('a\0o a\0o a\0o', 'a', 1) -- first position assert(a == 1 and b == 1) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 2) -- starts in the middle assert(a == 5 and b == 7) -a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the midle +a,b = string.find('a\0o a\0o a\0o', 'a\0o', 9) -- starts in the middle assert(a == 9 and b == 11) a,b = string.find('a\0a\0a\0a\0\0ab', '\0ab', 2); -- finds at the end assert(a == 9 and b == 11); diff --git a/tests/conformance/tmerror.lua b/tests/conformance/tmerror.lua new file mode 100644 index 000000000..1ad4dd16f --- /dev/null +++ b/tests/conformance/tmerror.lua @@ -0,0 +1,15 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes + +-- Generate an error (i.e. throw an exception) inside a tag method which is indirectly +-- called via pcall. +-- This test is meant to detect a regression in handling errors inside a tag method + +local testtable = {} +setmetatable(testtable, { __index = function() error("Error") end }) + +pcall(function() + testtable.missingmethod() +end) + +return('OK') diff --git a/tools/gdb-printers.py b/tools/gdb-printers.py index c711c5e2e..017b9f95e 100644 --- a/tools/gdb-printers.py +++ b/tools/gdb-printers.py @@ -11,9 +11,9 @@ def to_string(self): return type.name + " [" + str(value) + "]" def match_printer(val): - type = val.type.strip_typedefs() - if type.name and type.name.startswith('Luau::Variant<'): - return VariantPrinter(val) - return None + type = val.type.strip_typedefs() + if type.name and type.name.startswith('Luau::Variant<'): + return VariantPrinter(val) + return None gdb.pretty_printers.append(match_printer) From 82d74e6f73fa8d9a81c4c932c668543c08c27597 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 18:12:39 -0800 Subject: [PATCH 03/32] Sync to upstream/release/504 --- .gitignore | 1 + Analysis/include/Luau/Error.h | 13 +- Analysis/include/Luau/FileResolver.h | 2 +- Analysis/include/Luau/Transpiler.h | 3 +- Analysis/include/Luau/TypeInfer.h | 8 +- Analysis/include/Luau/TypePack.h | 20 +- Analysis/include/Luau/TypeVar.h | 45 +- Analysis/include/Luau/Unifiable.h | 3 - Analysis/include/Luau/Unifier.h | 2 +- Analysis/src/Autocomplete.cpp | 46 +- Analysis/src/BuiltinDefinitions.cpp | 244 +------- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 22 +- Analysis/src/Error.cpp | 291 ++++++---- Analysis/src/Frontend.cpp | 9 +- Analysis/src/Linter.cpp | 10 +- Analysis/src/Module.cpp | 17 +- Analysis/src/Predicate.cpp | 4 - Analysis/src/RequireTracer.cpp | 8 +- Analysis/src/Substitution.cpp | 13 +- Analysis/src/ToString.cpp | 114 +--- Analysis/src/Transpiler.cpp | 117 +++- Analysis/src/TypeAttach.cpp | 41 +- Analysis/src/TypeInfer.cpp | 586 ++++++-------------- Analysis/src/TypePack.cpp | 8 +- Analysis/src/TypeVar.cpp | 70 ++- Analysis/src/Unifiable.cpp | 10 - Analysis/src/Unifier.cpp | 345 ++++++------ Ast/include/Luau/Parser.h | 1 - Ast/src/Parser.cpp | 51 +- CLI/Analyze.cpp | 27 +- CLI/Repl.cpp | 93 +++- CMakeLists.txt | 38 +- Compiler/include/Luau/Bytecode.h | 4 + Compiler/include/Luau/Compiler.h | 3 + Compiler/src/Compiler.cpp | 47 +- Makefile | 10 +- VM/include/lua.h | 4 + VM/src/lapi.cpp | 11 + VM/src/lbitlib.cpp | 42 ++ VM/src/lbuiltins.cpp | 54 +- VM/src/lgc.cpp | 164 +----- VM/src/lstrlib.cpp | 20 +- VM/src/ltable.cpp | 1 + VM/src/ltable.h | 1 - VM/src/ltablib.cpp | 8 - bench/tests/chess.lua | 76 +-- fuzz/luau.proto | 7 + fuzz/proto.cpp | 7 + fuzz/protoprint.cpp | 10 + tests/AstQuery.test.cpp | 1 - tests/Autocomplete.test.cpp | 3 - tests/Compiler.test.cpp | 126 +++++ tests/Conformance.test.cpp | 6 +- tests/Linter.test.cpp | 6 - tests/Parser.test.cpp | 43 +- tests/Predicate.test.cpp | 6 - tests/ToString.test.cpp | 9 - tests/Transpiler.test.cpp | 252 ++++++++- tests/TypeInfer.aliases.test.cpp | 3 - tests/TypeInfer.builtins.test.cpp | 8 - tests/TypeInfer.classes.test.cpp | 25 +- tests/TypeInfer.definitions.test.cpp | 3 - tests/TypeInfer.generics.test.cpp | 89 --- tests/TypeInfer.intersectionTypes.test.cpp | 39 ++ tests/TypeInfer.provisional.test.cpp | 7 +- tests/TypeInfer.refinements.test.cpp | 44 +- tests/TypeInfer.tables.test.cpp | 72 +++ tests/TypeInfer.test.cpp | 34 +- tests/TypeInfer.tryUnify.test.cpp | 5 - tests/TypeInfer.typePacks.cpp | 12 - tests/TypeInfer.unionTypes.test.cpp | 41 +- tests/TypeVar.test.cpp | 2 - tests/conformance/bitwise.lua | 16 + 73 files changed, 1737 insertions(+), 1846 deletions(-) diff --git a/.gitignore b/.gitignore index 0b2422ced..fa11b45b5 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,4 @@ ^default.prof* ^fuzz-* ^luau$ +/.vs diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index ac6f13e96..9ee750043 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -8,11 +8,20 @@ namespace Luau { +struct TypeError; struct TypeMismatch { - TypeId wantedType; - TypeId givenType; + TypeMismatch() = default; + TypeMismatch(TypeId wantedType, TypeId givenType); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason); + TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error); + + TypeId wantedType = nullptr; + TypeId givenType = nullptr; + + std::string reason; + std::shared_ptr error; bool operator==(const TypeMismatch& rhs) const; }; diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index a05ec5e91..9b74fc12d 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -53,7 +53,7 @@ struct FileResolver } // DEPRECATED APIS - // These are going to be removed with LuauNewRequireTracer + // These are going to be removed with LuauNewRequireTrace2 virtual bool moduleExists(const ModuleName& name) const = 0; virtual std::optional fromAstFragment(AstExpr* expr) const = 0; virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; diff --git a/Analysis/include/Luau/Transpiler.h b/Analysis/include/Luau/Transpiler.h index 817459fed..df01008ca 100644 --- a/Analysis/include/Luau/Transpiler.h +++ b/Analysis/include/Luau/Transpiler.h @@ -18,6 +18,7 @@ struct TranspileResult std::string parseError; // Nonempty if the transpile failed }; +std::string toString(AstNode* node); void dump(AstNode* node); // Never fails on a well-formed AST @@ -25,6 +26,6 @@ std::string transpile(AstStatBlock& ast); std::string transpileWithTypes(AstStatBlock& block); // Only fails when parsing fails -TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}); +TranspileResult transpile(std::string_view source, ParseOptions options = ParseOptions{}, bool withTypes = false); } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9d62fef0b..306ac77d8 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -263,8 +263,6 @@ struct TypeChecker * */ TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); - // Removed by FFlag::LuauRankNTypes - TypePackId DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location); // Replace any free types or type packs by `any`. // This is used when exporting types from modules, to make sure free types don't leak. @@ -298,8 +296,6 @@ struct TypeChecker // Produce a new free type var. TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); - TypeId DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric = false); - TypeId DEPRECATED_freshType(TypeLevel level, bool canBeGeneric = false); // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -326,10 +322,8 @@ struct TypeChecker TypePackId addTypePack(std::initializer_list&& ty); TypePackId freshTypePack(const ScopePtr& scope); TypePackId freshTypePack(TypeLevel level); - TypePackId DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric = false); - TypePackId DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric = false); - TypeId resolveType(const ScopePtr& scope, const AstType& annotation, bool canBeGeneric = false); + TypeId resolveType(const ScopePtr& scope, const AstType& annotation); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypeList& types); TypePackId resolveTypePack(const ScopePtr& scope, const AstTypePack& annotation); TypeId instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index d987d46ca..e72808da7 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAG(LuauAddMissingFollow) - namespace Luau { @@ -128,13 +126,10 @@ TypePack* asMutable(const TypePack* tp); template const T* get(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(tp->ty)); } @@ -142,13 +137,10 @@ const T* get(TypePackId tp) template T* getMutable(TypePackId tp) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tp); + LUAU_ASSERT(tp); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tp->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tp->ty) == nullptr); return get_if(&(asMutable(tp)->ty)); } diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 9611e881f..6bd7932db 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -18,7 +18,6 @@ LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeMaximumStringifierLength) -LUAU_FASTFLAG(LuauAddMissingFollow) namespace Luau { @@ -413,13 +412,17 @@ bool maybeGeneric(const TypeId ty); struct SingletonTypes { - const TypeId nilType = &nilType_; - const TypeId numberType = &numberType_; - const TypeId stringType = &stringType_; - const TypeId booleanType = &booleanType_; - const TypeId threadType = &threadType_; - const TypeId anyType = &anyType_; - const TypeId errorType = &errorType_; + const TypeId nilType; + const TypeId numberType; + const TypeId stringType; + const TypeId booleanType; + const TypeId threadType; + const TypeId anyType; + const TypeId errorType; + const TypeId optionalNumberType; + + const TypePackId anyTypePack; + const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; @@ -427,14 +430,6 @@ struct SingletonTypes private: std::unique_ptr arena; - TypeVar nilType_; - TypeVar numberType_; - TypeVar stringType_; - TypeVar booleanType_; - TypeVar threadType_; - TypeVar anyType_; - TypeVar errorType_; - TypeId makeStringMetatable(); }; @@ -472,13 +467,10 @@ TypeVar* asMutable(TypeId ty); template const T* get(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&tv->ty); } @@ -486,13 +478,10 @@ const T* get(TypeId tv) template T* getMutable(TypeId tv) { - if (FFlag::LuauAddMissingFollow) - { - LUAU_ASSERT(tv); + LUAU_ASSERT(tv); - if constexpr (!std::is_same_v) - LUAU_ASSERT(get_if(&tv->ty) == nullptr); - } + if constexpr (!std::is_same_v) + LUAU_ASSERT(get_if(&tv->ty) == nullptr); return get_if(&asMutable(tv)->ty); } diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index 10dbf3335..c2e07e466 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -63,12 +63,9 @@ using Name = std::string; struct Free { explicit Free(TypeLevel level); - Free(TypeLevel level, bool DEPRECATED_canBeGeneric); int index; TypeLevel level; - // Removed by FFlag::LuauRankNTypes - bool DEPRECATED_canBeGeneric = false; // True if this free type variable is part of a mutually // recursive type alias whose definitions haven't been // resolved yet. diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 56632e33c..be0aadd05 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -87,7 +87,6 @@ struct Unifier void tryUnifyWithAny(TypePackId any, TypePackId ty); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); - std::optional findMetatableEntry(TypeId type, std::string entry); public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" @@ -102,6 +101,7 @@ struct Unifier bool isNonstrictMode() const; void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType); + void checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType); [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 3c43c8086..1c94bb684 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,7 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) @@ -369,20 +368,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId while (iter != endIter) { - if (FFlag::LuauAddMissingFollow) - { - if (isNil(*iter)) - ++iter; - else - break; - } + if (isNil(*iter)) + ++iter; else - { - if (auto primTy = Luau::get(*iter); primTy && primTy->type == PrimitiveTypeVar::NilType) - ++iter; - else - break; - } + break; } if (iter == endIter) @@ -397,21 +386,10 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - if (FFlag::LuauAddMissingFollow) + if (isNil(*iter)) { - if (isNil(*iter)) - { - ++iter; - continue; - } - } - else - { - if (auto innerPrimTy = Luau::get(*iter); innerPrimTy && innerPrimTy->type == PrimitiveTypeVar::NilType) - { - ++iter; - continue; - } + ++iter; + continue; } autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen); @@ -496,7 +474,7 @@ static bool canSuggestInferredType(ScopePtr scope, TypeId ty) return false; // No syntax for unnamed tables with a metatable - if (const MetatableTypeVar* mtv = get(ty)) + if (get(ty)) return false; if (const TableTypeVar* ttv = get(ty)) @@ -688,7 +666,7 @@ static std::optional functionIsExpectedAt(const Module& module, AstNode* n TypeId expectedType = follow(*it); - if (const FunctionTypeVar* ftv = get(expectedType)) + if (get(expectedType)) return true; if (const IntersectionTypeVar* itv = get(expectedType)) @@ -1519,10 +1497,10 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName return {}; TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1550,7 +1528,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->commentLocations = std::move(result.commentLocations); TypeChecker& typeChecker = - (frontend.options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index f6f2363c6..62a06a3cd 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,10 +8,7 @@ #include -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -185,25 +182,11 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId numberType = typeChecker.numberType; TypeId booleanType = typeChecker.booleanType; TypeId nilType = typeChecker.nilType; - TypeId stringType = typeChecker.stringType; - TypeId threadType = typeChecker.threadType; - TypeId anyType = typeChecker.anyType; TypeArena& arena = typeChecker.globalTypes; - TypeId optionalNumber = makeOption(typeChecker, arena, numberType); - TypeId optionalString = makeOption(typeChecker, arena, stringType); - TypeId optionalBoolean = makeOption(typeChecker, arena, booleanType); - - TypeId stringOrNumber = makeUnion(arena, {stringType, numberType}); - - TypePackId emptyPack = arena.addTypePack({}); TypePackId oneNumberPack = arena.addTypePack({numberType}); - TypePackId oneStringPack = arena.addTypePack({stringType}); TypePackId oneBooleanPack = arena.addTypePack({booleanType}); - TypePackId oneAnyPack = arena.addTypePack({anyType}); - - TypePackId anyTypePack = typeChecker.anyTypePack; TypePackId numberVariadicList = arena.addTypePack(TypePackVar{VariadicTypePack{numberType}}); TypePackId listOfAtLeastOneNumber = arena.addTypePack(TypePack{{numberType}, numberVariadicList}); @@ -215,8 +198,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId listOfAtLeastZeroNumbersToNumberType = arena.addType(FunctionTypeVar{numberVariadicList, oneNumberPack}); - TypeId stringToAnyMap = arena.addType(TableTypeVar{{}, TableIndexer(stringType, anyType), typeChecker.globalScope->level}); - LoadDefinitionFileResult loadResult = Luau::loadDefinitionFile(typeChecker, typeChecker.globalScope, getBuiltinDefinitionSource(), "@luau"); LUAU_ASSERT(loadResult.success); @@ -236,8 +217,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) ttv->props["btest"] = makeProperty(arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneBooleanPack}), "@luau/global/bit32.btest"); } - TypeId anyFunction = arena.addType(FunctionTypeVar{anyTypePack, anyTypePack}); - TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); @@ -252,222 +231,6 @@ void registerBuiltinTypes(TypeChecker& typeChecker) addGlobalBinding(typeChecker, "string", it->second.type, "@luau"); - if (!FFlag::LuauParseGenericFunctions || !FFlag::LuauGenericFunctions) - { - TableTypeVar::Props debugLib{ - {"info", {makeIntersection(arena, - { - arena.addType(FunctionTypeVar{arena.addTypePack({typeChecker.threadType, numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({numberType, stringType}), anyTypePack}), - arena.addType(FunctionTypeVar{arena.addTypePack({anyFunction, stringType}), anyTypePack}), - })}}, - {"traceback", {makeIntersection(arena, - { - makeFunction(arena, std::nullopt, {optionalString, optionalNumber}, {stringType}), - makeFunction(arena, std::nullopt, {typeChecker.threadType, optionalString, optionalNumber}, {stringType}), - })}}, - }; - - assignPropDocumentationSymbols(debugLib, "@luau/global/debug"); - addGlobalBinding(typeChecker, "debug", - arena.addType(TableTypeVar{debugLib, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}), "@luau"); - - TableTypeVar::Props utf8Lib = { - {"char", {arena.addType(FunctionTypeVar{listOfAtLeastOneNumber, oneStringPack})}}, // FIXME - {"charpattern", {stringType}}, - {"codes", {makeFunction(arena, std::nullopt, {stringType}, - {makeFunction(arena, std::nullopt, {stringType, numberType}, {numberType, numberType}), stringType, numberType})}}, - {"codepoint", - {arena.addType(FunctionTypeVar{arena.addTypePack({stringType, optionalNumber, optionalNumber}), listOfAtLeastOneNumber})}}, // FIXME - {"len", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {optionalNumber, numberType})}}, - {"offset", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, {numberType})}}, - {"nfdnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - {"graphemes", {makeFunction(arena, std::nullopt, {stringType, optionalNumber, optionalNumber}, - {makeFunction(arena, std::nullopt, {}, {numberType, numberType})})}}, - {"nfcnormalize", {makeFunction(arena, std::nullopt, {stringType}, {stringType})}}, - }; - - assignPropDocumentationSymbols(utf8Lib, "@luau/global/utf8"); - addGlobalBinding( - typeChecker, "utf8", arena.addType(TableTypeVar{utf8Lib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId optionalV = makeOption(typeChecker, arena, genericV); - - TypeId arrayOfV = arena.addType(TableTypeVar{{}, TableIndexer(numberType, genericV), typeChecker.globalScope->level}); - - TypePackId unpackArgsPack = arena.addTypePack(TypePack{{arrayOfV, optionalNumber, optionalNumber}}); - TypePackId unpackReturnPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypeId unpackFunc = arena.addType(FunctionTypeVar{{genericV}, {}, unpackArgsPack, unpackReturnPack}); - - TypeId packResult = arena.addType(TableTypeVar{ - TableTypeVar::Props{{"n", {numberType}}}, TableIndexer{numberType, numberType}, typeChecker.globalScope->level, TableState::Sealed}); - TypePackId packArgsPack = arena.addTypePack(TypePack{{}, anyTypePack}); - TypePackId packReturnPack = arena.addTypePack(TypePack{{packResult}}); - - TypeId comparator = makeFunction(arena, std::nullopt, {genericV, genericV}, {booleanType}); - TypeId optionalComparator = makeOption(typeChecker, arena, comparator); - - TypeId packFn = arena.addType(FunctionTypeVar(packArgsPack, packReturnPack)); - - TableTypeVar::Props tableLib = { - {"concat", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalString, optionalNumber, optionalNumber}, {stringType})}}, - {"insert", {makeIntersection(arena, {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV}, {}), - makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, genericV}, {})})}}, - {"maxn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"remove", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalNumber}, {optionalV})}}, - {"sort", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, optionalComparator}, {})}}, - {"create", {makeFunction(arena, std::nullopt, {genericV}, {}, {numberType, optionalV}, {arrayOfV})}}, - {"find", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, genericV, optionalNumber}, {optionalNumber})}}, - - {"unpack", {unpackFunc}}, // FIXME - {"pack", {packFn}}, - - // Lua 5.0 compat - {"getn", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV}, {numberType})}}, - {"foreach", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, - {mapOfKtoV, makeFunction(arena, std::nullopt, {genericK, genericV}, {})}, {})}}, - {"foreachi", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, makeFunction(arena, std::nullopt, {genericV}, {})}, {})}}, - - // backported from Lua 5.3 - {"move", {makeFunction(arena, std::nullopt, {genericV}, {}, {arrayOfV, numberType, numberType, numberType, arrayOfV}, {})}}, - - // added in Luau (borrowed from LuaJIT) - {"clear", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {})}}, - - {"freeze", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {mapOfKtoV})}}, - {"isfrozen", {makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(tableLib, "@luau/global/table"); - addGlobalBinding( - typeChecker, "table", arena.addType(TableTypeVar{tableLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TableTypeVar::Props coroutineLib = { - {"create", {makeFunction(arena, std::nullopt, {anyFunction}, {threadType})}}, - {"resume", {arena.addType(FunctionTypeVar{arena.addTypePack(TypePack{{threadType}, anyTypePack}), anyTypePack})}}, - {"running", {makeFunction(arena, std::nullopt, {}, {threadType})}}, - {"status", {makeFunction(arena, std::nullopt, {threadType}, {stringType})}}, - {"wrap", {makeFunction( - arena, std::nullopt, {anyFunction}, {anyType})}}, // FIXME this technically returns a function, but we can't represent this - // atm since it can be called with different arg types at different times - {"yield", {arena.addType(FunctionTypeVar{anyTypePack, anyTypePack})}}, - {"isyieldable", {makeFunction(arena, std::nullopt, {}, {booleanType})}}, - }; - - assignPropDocumentationSymbols(coroutineLib, "@luau/global/coroutine"); - addGlobalBinding(typeChecker, "coroutine", - arena.addType(TableTypeVar{coroutineLib, std::nullopt, typeChecker.globalScope->level, TableState::Sealed}), "@luau"); - - TypeId genericT = arena.addType(GenericTypeVar{"T"}); - TypeId genericR = arena.addType(GenericTypeVar{"R"}); - - // assert returns all arguments - TypePackId assertArgs = arena.addTypePack({genericT, optionalString}); - TypePackId assertRets = arena.addTypePack({genericT}); - addGlobalBinding(typeChecker, "assert", arena.addType(FunctionTypeVar{assertArgs, assertRets}), "@luau"); - - addGlobalBinding(typeChecker, "print", arena.addType(FunctionTypeVar{anyTypePack, emptyPack}), "@luau"); - - addGlobalBinding(typeChecker, "type", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding(typeChecker, "typeof", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - - addGlobalBinding(typeChecker, "error", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {}), "@luau"); - - addGlobalBinding(typeChecker, "tostring", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT}, {stringType}), "@luau"); - addGlobalBinding( - typeChecker, "tonumber", makeFunction(arena, std::nullopt, {genericT}, {}, {genericT, optionalNumber}, {numberType}), "@luau"); - - addGlobalBinding( - typeChecker, "rawequal", makeFunction(arena, std::nullopt, {genericT, genericR}, {}, {genericT, genericR}, {booleanType}), "@luau"); - addGlobalBinding( - typeChecker, "rawget", makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK}, {genericV}), "@luau"); - addGlobalBinding(typeChecker, "rawset", - makeFunction(arena, std::nullopt, {genericK, genericV}, {}, {mapOfKtoV, genericK, genericV}, {mapOfKtoV}), "@luau"); - - TypePackId genericTPack = arena.addTypePack({genericT}); - TypePackId genericRPack = arena.addTypePack({genericR}); - TypeId genericArgsToReturnFunction = arena.addType( - FunctionTypeVar{{genericT, genericR}, {}, arena.addTypePack(TypePack{{}, genericTPack}), arena.addTypePack(TypePack{{}, genericRPack})}); - - TypeId setfenvArgType = makeUnion(arena, {numberType, genericArgsToReturnFunction}); - TypeId setfenvReturnType = makeOption(typeChecker, arena, genericArgsToReturnFunction); - addGlobalBinding(typeChecker, "setfenv", makeFunction(arena, std::nullopt, {setfenvArgType, stringToAnyMap}, {setfenvReturnType}), "@luau"); - - TypePackId ipairsArgsTypePack = arena.addTypePack({arrayOfV}); - - TypeId ipairsNextFunctionType = arena.addType( - FunctionTypeVar{{genericK, genericV}, {}, arena.addTypePack({arrayOfV, numberType}), arena.addTypePack({numberType, genericV})}); - - // ipairs returns 'next, Array, 0' so we would need type-level primitives and change to - // again, we have a direct reference to 'next' because ipairs returns it - // ipairs(t: Array) -> ((Array) -> (number, V), Array, 0) - TypePackId ipairsReturnTypePack = arena.addTypePack(TypePack{{ipairsNextFunctionType, arrayOfV, numberType}}); - - // ipairs(t: Array) -> ((Array) -> (number, V), Array, number) - addGlobalBinding(typeChecker, "ipairs", arena.addType(FunctionTypeVar{{genericV}, {}, ipairsArgsTypePack, ipairsReturnTypePack}), "@luau"); - - TypePackId pcallArg0FnArgs = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId pcallArg0FnRet = arena.addTypePack(TypePackVar{GenericTypeVar{"R"}}); - TypeId pcallArg0 = arena.addType(FunctionTypeVar{pcallArg0FnArgs, pcallArg0FnRet}); - TypePackId pcallArgsTypePack = arena.addTypePack(TypePack{{pcallArg0}, pcallArg0FnArgs}); - - TypePackId pcallReturnTypePack = arena.addTypePack(TypePack{{booleanType}, pcallArg0FnRet}); - - // pcall(f: (A...) -> R..., args: A...) -> boolean, R... - addGlobalBinding(typeChecker, "pcall", - arena.addType(FunctionTypeVar{{}, {pcallArg0FnArgs, pcallArg0FnRet}, pcallArgsTypePack, pcallReturnTypePack}), "@luau"); - - // errors thrown by the function 'f' are propagated onto the function 'err' that accepts it. - // and either 'f' or 'err' are valid results of this xpcall - // if 'err' did throw an error, then it returns: false, "error in error handling" - // TODO: the above is not represented (nor representable) in the type annotation below. - // - // The real type of xpcall is as such: (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, - // R2...) - TypePackId genericAPack = arena.addTypePack(TypePackVar{GenericTypeVar{"A"}}); - TypePackId genericR1Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R1"}}); - TypePackId genericR2Pack = arena.addTypePack(TypePackVar{GenericTypeVar{"R2"}}); - - TypeId genericE = arena.addType(GenericTypeVar{"E"}); - - TypeId xpcallFArg = arena.addType(FunctionTypeVar{genericAPack, genericR1Pack}); - TypeId xpcallErrArg = arena.addType(FunctionTypeVar{arena.addTypePack({genericE}), genericR2Pack}); - - TypePackId xpcallArgsPack = arena.addTypePack({{xpcallFArg, xpcallErrArg}, genericAPack}); - TypePackId xpcallRetPack = arena.addTypePack({{booleanType}, genericR1Pack}); // FIXME - - addGlobalBinding(typeChecker, "xpcall", - arena.addType(FunctionTypeVar{{genericE}, {genericAPack, genericR1Pack, genericR2Pack}, xpcallArgsPack, xpcallRetPack}), "@luau"); - - addGlobalBinding(typeChecker, "unpack", unpackFunc, "@luau"); - - TypePackId selectArgsTypePack = arena.addTypePack(TypePack{ - {stringOrNumber}, - anyTypePack // FIXME? select() is tricky. - }); - - addGlobalBinding(typeChecker, "select", arena.addType(FunctionTypeVar{selectArgsTypePack, anyTypePack}), "@luau"); - - // TODO: not completely correct. loadstring's return type should be a function or (nil, string) - TypeId loadstringFunc = arena.addType(FunctionTypeVar{anyTypePack, oneAnyPack}); - - addGlobalBinding(typeChecker, "loadstring", - makeFunction(arena, std::nullopt, {stringType, optionalString}, - { - makeOption(typeChecker, arena, loadstringFunc), - makeOption(typeChecker, arena, stringType), - }), - "@luau"); - - // a userdata object is "roughly" the same as a sealed empty table - // except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - // another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - // setmetatable. - // TODO: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - TypeId sealedTable = arena.addType(TableTypeVar(TableState::Sealed, typeChecker.globalScope->level)); - addGlobalBinding(typeChecker, "newproxy", makeFunction(arena, std::nullopt, {optionalBoolean}, {sealedTable}), "@luau"); - } - // next(t: Table, i: K | nil) -> (K, V) TypePackId nextArgsTypePack = arena.addTypePack(TypePack{{mapOfKtoV, makeOption(typeChecker, arena, genericK)}}); addGlobalBinding(typeChecker, "next", @@ -475,8 +238,7 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypePackId pairsArgsTypePack = arena.addTypePack({mapOfKtoV}); - TypeId pairsNext = (FFlag::LuauRankNTypes ? arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}) - : getGlobalBinding(typeChecker, "next")); + TypeId pairsNext = arena.addType(FunctionTypeVar{nextArgsTypePack, arena.addTypePack(TypePack{{genericK, genericV}})}); TypePackId pairsReturnTypePack = arena.addTypePack(TypePack{{pairsNext, mapOfKtoV, nilType}}); // NOTE we are missing 'i: K | nil' argument in the first return types' argument. @@ -711,7 +473,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - const AstExpr* require = FFlag::LuauNewRequireTrace ? &expr : expr.args.data[0]; + const AstExpr* require = FFlag::LuauNewRequireTrace2 ? &expr : expr.args.data[0]; if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 1e91561a6..96703ef16 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,9 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAG(LuauParseGenericFunctions) -LUAU_FASTFLAG(LuauGenericFunctions) - namespace Luau { @@ -19,6 +16,8 @@ declare bit32: { bnot: (number) -> number, extract: (number, number, number?) -> number, replace: (number, number, number, number?) -> number, + countlz: (number) -> number, + countrz: (number) -> number, } declare math: { @@ -103,15 +102,6 @@ declare _VERSION: string declare function gcinfo(): number -)BUILTIN_SRC"; - -std::string getBuiltinDefinitionSource() -{ - std::string src = kBuiltinDefinitionLuaSrc; - - if (FFlag::LuauParseGenericFunctions && FFlag::LuauGenericFunctions) - { - src += R"( declare function print(...: T...) declare function type(value: T): string @@ -208,10 +198,12 @@ std::string getBuiltinDefinitionSource() -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. declare function unpack(tab: {V}, i: number?, j: number?): ...V - )"; - } - return src; +)BUILTIN_SRC"; + +std::string getBuiltinDefinitionSource() +{ + return kBuiltinDefinitionLuaSrc; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 04d91444a..46ff2c72a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -94,8 +94,23 @@ struct ErrorConverter { std::string operator()(const Luau::TypeMismatch& tm) const { - ToStringOptions opts; - return "Type '" + Luau::toString(tm.givenType, opts) + "' could not be converted into '" + Luau::toString(tm.wantedType, opts) + "'"; + std::string result = "Type '" + Luau::toString(tm.givenType) + "' could not be converted into '" + Luau::toString(tm.wantedType) + "'"; + + if (tm.error) + { + result += "\ncaused by:\n "; + + if (!tm.reason.empty()) + result += tm.reason + ". "; + + result += Luau::toString(*tm.error); + } + else if (!tm.reason.empty()) + { + result += "; " + tm.reason; + } + + return result; } std::string operator()(const Luau::UnknownSymbol& e) const @@ -478,9 +493,36 @@ struct InvalidNameChecker } }; +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType) + : wantedType(wantedType) + , givenType(givenType) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) +{ +} + +TypeMismatch::TypeMismatch(TypeId wantedType, TypeId givenType, std::string reason, TypeError error) + : wantedType(wantedType) + , givenType(givenType) + , reason(reason) + , error(std::make_shared(std::move(error))) +{ +} + bool TypeMismatch::operator==(const TypeMismatch& rhs) const { - return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType; + if (!!error != !!rhs.error) + return false; + + if (error && !(*error == *rhs.error)) + return false; + + return *wantedType == *rhs.wantedType && *givenType == *rhs.givenType && reason == rhs.reason; } bool UnknownSymbol::operator==(const UnknownSymbol& rhs) const @@ -690,130 +732,141 @@ bool containsParseErrorName(const TypeError& error) return Luau::visit(InvalidNameChecker{}, error.data); } -void copyErrors(ErrorVec& errors, struct TypeArena& destArena) +template +void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - auto clone = [&](auto&& ty) { return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); }; auto visitErrorData = [&](auto&& e) { - using T = std::decay_t; + copyError(e, destArena, seenTypes, seenTypePacks); + }; - if constexpr (false) - { - } - else if constexpr (std::is_same_v) - { - e.wantedType = clone(e.wantedType); - e.givenType = clone(e.givenType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - e.tableType = clone(e.tableType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.typeFun = clone(e.typeFun); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.table = clone(e.table); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.ty = clone(e.ty); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.expectedReturnType = clone(e.expectedReturnType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.superType = clone(e.superType); - e.subType = clone(e.subType); - } - else if constexpr (std::is_same_v) - { - } - else if constexpr (std::is_same_v) - { - e.optional = clone(e.optional); - } - else if constexpr (std::is_same_v) - { - e.type = clone(e.type); + if constexpr (false) + { + } + else if constexpr (std::is_same_v) + { + e.wantedType = clone(e.wantedType); + e.givenType = clone(e.givenType); - for (auto& ty : e.missing) - ty = clone(ty); - } - else - static_assert(always_false_v, "Non-exhaustive type switch"); + if (e.error) + visit(visitErrorData, e.error->data); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + e.tableType = clone(e.tableType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.typeFun = clone(e.typeFun); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.table = clone(e.table); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.ty = clone(e.ty); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.expectedReturnType = clone(e.expectedReturnType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.superType = clone(e.superType); + e.subType = clone(e.subType); + } + else if constexpr (std::is_same_v) + { + } + else if constexpr (std::is_same_v) + { + e.optional = clone(e.optional); + } + else if constexpr (std::is_same_v) + { + e.type = clone(e.type); + + for (auto& ty : e.missing) + ty = clone(ty); + } + else + static_assert(always_false_v, "Non-exhaustive type switch"); +} + +void copyErrors(ErrorVec& errors, TypeArena& destArena) +{ + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + + auto visitErrorData = [&](auto&& e) { + copyError(e, destArena, seenTypes, seenTypePacks); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 5e7af50c5..2f411274b 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,11 +18,10 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauSecondTypecheckKnowsTheDataModel, false) LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau @@ -415,7 +414,7 @@ CheckResult Frontend::check(const ModuleName& name) // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (options.typecheckTwice && FFlag::LuauSecondTypecheckKnowsTheDataModel) + if (options.typecheckTwice) { ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; @@ -897,7 +896,7 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module const auto& exprs = it->second.exprs; const ModuleInfo* info = exprs.find(&pathExpr); - if (!info || (!FFlag::LuauNewRequireTrace && info->name.empty())) + if (!info || (!FFlag::LuauNewRequireTrace2 && info->name.empty())) return std::nullopt; return *info; @@ -914,7 +913,7 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - if (FFlag::LuauNewRequireTrace) + if (FFlag::LuauNewRequireTrace2) return frontend->sourceNodes.count(moduleName) != 0; else return frontend->fileResolver->moduleExists(moduleName); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index bff947a56..1a5b24fe2 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,9 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauLinterUnknownTypeVectorAware, false) -LUAU_FASTFLAGVARIABLE(LuauLinterTableMoveZero, false) - namespace Luau { @@ -1110,10 +1107,7 @@ class LintUnknownType : AstVisitor if (g && g->name == "type") { - if (FFlag::LuauLinterUnknownTypeVectorAware) - validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); - else - validateType(arg, {Kind_Primitive}, "primitive type"); + validateType(arg, {Kind_Primitive, Kind_Vector}, "primitive type"); } else if (g && g->name == "typeof") { @@ -2146,7 +2140,7 @@ class LintTableOperations : AstVisitor "wrap it in parentheses to silence"); } - if (FFlag::LuauLinterTableMoveZero && func->index == "move" && node->args.size >= 4) + if (func->index == "move" && node->args.size >= 4) { // table.move(t, 0, _, _) if (isConstant(args[1], 0.0)) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 2fd958965..880ffd2e5 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -12,7 +12,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) @@ -290,9 +289,7 @@ void TypeCloner::operator()(const FunctionTypeVar& t) for (TypePackId genericPack : t.genericPacks) ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ftv->tags = t.tags; - + ftv->tags = t.tags; ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); ftv->argNames = t.argNames; ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); @@ -319,12 +316,7 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; for (const auto& [name, prop] : t.props) - { - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; - else - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; - } + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), @@ -379,10 +371,7 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; - else - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location}; + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp index 25e63bffe..848627cf8 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/Predicate.cpp @@ -3,8 +3,6 @@ #include "Luau/Ast.h" -LUAU_FASTFLAG(LuauOrPredicate) - namespace Luau { @@ -60,8 +58,6 @@ std::string toString(const LValue& lvalue) void merge(RefinementMap& l, const RefinementMap& r, std::function f) { - LUAU_ASSERT(FFlag::LuauOrPredicate); - auto itL = l.begin(); auto itR = r.begin(); while (itL != l.end() && itR != r.end()) diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index 95910b562..b72f53f99 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -5,7 +5,7 @@ #include "Luau/Module.h" LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) -LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace, false) +LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace2, false) namespace Luau { @@ -19,7 +19,7 @@ struct RequireTracerOld : AstVisitor : fileResolver(fileResolver) , currentModuleName(currentModuleName) { - LUAU_ASSERT(!FFlag::LuauNewRequireTrace); + LUAU_ASSERT(!FFlag::LuauNewRequireTrace2); } FileResolver* const fileResolver; @@ -188,7 +188,7 @@ struct RequireTracer : AstVisitor , currentModuleName(currentModuleName) , locals(nullptr) { - LUAU_ASSERT(FFlag::LuauNewRequireTrace); + LUAU_ASSERT(FFlag::LuauNewRequireTrace2); } bool visit(AstExprTypeAssertion* expr) override @@ -332,7 +332,7 @@ struct RequireTracer : AstVisitor RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - if (FFlag::LuauNewRequireTrace) + if (FFlag::LuauNewRequireTrace2) { RequireTraceResult result; RequireTracer tracer{result, fileResolver, currentModuleName}; diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index d861eb3da..ca2b30f52 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -8,8 +8,6 @@ LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) -LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -19,7 +17,7 @@ void Tarjan::visitChildren(TypeId ty, int index) { ty = follow(ty); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) return; if (const FunctionTypeVar* ftv = get(ty)) @@ -68,7 +66,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) { tp = follow(tp); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) return; if (const TypePack* tpp = get(tp)) @@ -399,8 +397,7 @@ TypeId Substitution::clone(TypeId ty) if (FFlag::LuauTypeAliasPacks) clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - if (FFlag::LuauSecondTypecheckKnowsTheDataModel) - clone.tags = ttv->tags; + clone.tags = ttv->tags; result = addType(std::move(clone)); } else if (const MetatableTypeVar* mtv = get(ty)) @@ -486,7 +483,7 @@ void Substitution::replaceChildren(TypeId ty) { ty = follow(ty); - if (FFlag::LuauRankNTypes && ignoreChildren(ty)) + if (ignoreChildren(ty)) return; if (FunctionTypeVar* ftv = getMutable(ty)) @@ -535,7 +532,7 @@ void Substitution::replaceChildren(TypePackId tp) { tp = follow(tp); - if (FFlag::LuauRankNTypes && ignoreChildren(tp)) + if (ignoreChildren(tp)) return; if (TypePack* tpp = getMutable(tp)) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index cd8180dba..885fd489b 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAGVARIABLE(LuauInstantiatedTypeParamRecursion, false) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -237,15 +236,6 @@ struct TypeVarStringifier return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tv)) - { - state.emit(state.getName(tv)); - return; - } - } - Luau::visit( [this, tv](auto&& t) { return (*this)(tv, t); @@ -316,11 +306,7 @@ struct TypeVarStringifier void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; - - if (FFlag::LuauAddMissingFollow) - state.emit(state.getName(ty)); - else - state.emit(""); + state.emit(state.getName(ty)); } void operator()(TypeId, const BoundTypeVar& btv) @@ -724,16 +710,6 @@ struct TypePackStringifier return; } - if (!FFlag::LuauAddMissingFollow) - { - if (get(tp)) - { - state.emit(state.getName(tp)); - state.emit("..."); - return; - } - } - auto it = state.cycleTpNames.find(tp); if (it != state.cycleTpNames.end()) { @@ -821,16 +797,8 @@ struct TypePackStringifier void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; - - if (FFlag::LuauAddMissingFollow) - { - state.emit(state.getName(tp)); - state.emit("..."); - } - else - { - state.emit(""); - } + state.emit(state.getName(tp)); + state.emit("..."); } void operator()(TypePackId, const BoundTypePack& btv) @@ -864,23 +832,15 @@ static void assignCycleNames(const std::unordered_set& cycles, const std std::string name; // TODO: use the stringified type list if there are no cycles - if (FFlag::LuauInstantiatedTypeParamRecursion) + if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - { - // If we have a cycle type in type parameters, assign a cycle name for this named table - if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { - return cycles.count(follow(el)); - }) != ttv->instantiatedTypeParams.end()) - cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; + // If we have a cycle type in type parameters, assign a cycle name for this named table + if (std::find_if(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), [&](auto&& el) { + return cycles.count(follow(el)); + }) != ttv->instantiatedTypeParams.end()) + cycleNames[cycleTy] = ttv->name ? *ttv->name : *ttv->syntheticName; - continue; - } - } - else - { - if (auto ttv = get(follow(cycleTy)); !exhaustive && ttv && (ttv->syntheticName || ttv->name)) - continue; + continue; } name = "t" + std::to_string(nextIndex); @@ -912,58 +872,6 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) ToStringResult result; - if (!FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) - { - if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) - { - if (ttv->syntheticName) - result.invalid = true; - - // If scope if provided, add module name and check visibility - if (ttv->name && opts.scope) - { - auto [success, moduleName] = canUseTypeNameInScope(opts.scope, *ttv->name); - - if (!success) - result.invalid = true; - - if (moduleName) - result.name = format("%s.", moduleName->c_str()); - } - - result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - - if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) - return result; - - std::vector params; - for (TypeId tp : ttv->instantiatedTypeParams) - params.push_back(toString(tp)); - - if (FFlag::LuauTypeAliasPacks) - { - // Doesn't preserve grouping of multiple type packs - // But this is under a parent block of code that is being removed later - for (TypePackId tp : ttv->instantiatedTypePackParams) - { - std::string content = toString(tp); - - if (!content.empty()) - params.push_back(std::move(content)); - } - } - - result.name += "<" + join(params, ", ") + ">"; - return result; - } - else if (auto mtv = get(ty); mtv && mtv->syntheticName) - { - result.invalid = true; - result.name = *mtv->syntheticName; - return result; - } - } - StringifierState state{opts, result, opts.nameMap}; std::unordered_set cycles; @@ -975,7 +883,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) TypeVarStringifier tvs{state}; - if (FFlag::LuauInstantiatedTypeParamRecursion && !opts.exhaustive) + if (!opts.exhaustive) { if (auto ttv = get(ty); ttv && (ttv->name || ttv->syntheticName)) { diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 1b83ccdc2..7d880af49 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace @@ -97,9 +96,6 @@ struct Writer { virtual ~Writer() {} - virtual void begin() {} - virtual void end() {} - virtual void advance(const Position&) = 0; virtual void newline() = 0; virtual void space() = 0; @@ -131,6 +127,7 @@ struct StringWriter : Writer if (pos.column < newPos.column) write(std::string(newPos.column - pos.column, ' ')); } + void maybeSpace(const Position& newPos, int reserve) override { if (pos.column + reserve < newPos.column) @@ -279,11 +276,14 @@ struct Printer writer.identifier(func->index.value); } - void visualizeTypePackAnnotation(const AstTypePack& annotation) + void visualizeTypePackAnnotation(const AstTypePack& annotation, bool forVarArg) { + advance(annotation.location.begin); if (const AstTypePackVariadic* variadicTp = annotation.as()) { - writer.symbol("..."); + if (!forVarArg) + writer.symbol("..."); + visualizeTypeAnnotation(*variadicTp->variadicType); } else if (const AstTypePackGeneric* genericTp = annotation.as()) @@ -293,6 +293,7 @@ struct Printer } else if (const AstTypePackExplicit* explicitTp = annotation.as()) { + LUAU_ASSERT(!forVarArg); visualizeTypeList(explicitTp->typeList, true); } else @@ -317,7 +318,7 @@ struct Printer // Only variadic tail if (list.types.size == 0) { - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } else { @@ -345,7 +346,7 @@ struct Printer if (list.tailType) { writer.symbol(","); - visualizeTypePackAnnotation(*list.tailType); + visualizeTypePackAnnotation(*list.tailType, false); } writer.symbol(")"); @@ -542,6 +543,7 @@ struct Printer case AstExprBinary::CompareLt: case AstExprBinary::CompareGt: writer.maybeSpace(a->right->location.begin, 2); + writer.symbol(toString(a->op)); break; case AstExprBinary::Concat: case AstExprBinary::CompareNe: @@ -550,19 +552,35 @@ struct Printer case AstExprBinary::CompareGe: case AstExprBinary::Or: writer.maybeSpace(a->right->location.begin, 3); + writer.keyword(toString(a->op)); break; case AstExprBinary::And: writer.maybeSpace(a->right->location.begin, 4); + writer.keyword(toString(a->op)); break; } - writer.symbol(toString(a->op)); - visualize(*a->right); } else if (const auto& a = expr.as()) { visualize(*a->expr); + + if (writeTypes) + { + writer.maybeSpace(a->annotation->location.begin, 2); + writer.symbol("::"); + visualizeTypeAnnotation(*a->annotation); + } + } + else if (const auto& a = expr.as()) + { + writer.keyword("if"); + visualize(*a->condition); + writer.keyword("then"); + visualize(*a->trueExpr); + writer.keyword("else"); + visualize(*a->falseExpr); } else if (const auto& a = expr.as()) { @@ -769,24 +787,31 @@ struct Printer switch (a->op) { case AstExprBinary::Add: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("+="); break; case AstExprBinary::Sub: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("-="); break; case AstExprBinary::Mul: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("*="); break; case AstExprBinary::Div: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("/="); break; case AstExprBinary::Mod: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("%="); break; case AstExprBinary::Pow: + writer.maybeSpace(a->value->location.begin, 2); writer.symbol("^="); break; case AstExprBinary::Concat: + writer.maybeSpace(a->value->location.begin, 3); writer.symbol("..="); break; default: @@ -874,7 +899,7 @@ struct Printer void visualizeFunctionBody(AstExprFunction& func) { - if (FFlag::LuauGenericFunctions && (func.generics.size > 0 || func.genericPacks.size > 0)) + if (func.generics.size > 0 || func.genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -913,12 +938,13 @@ struct Printer if (func.vararg) { comma(); + advance(func.varargLocation.begin); writer.symbol("..."); if (func.varargAnnotation) { writer.symbol(":"); - visualizeTypePackAnnotation(*func.varargAnnotation); + visualizeTypePackAnnotation(*func.varargAnnotation, true); } } @@ -980,8 +1006,14 @@ struct Printer advance(typeAnnotation.location.begin); if (const auto& a = typeAnnotation.as()) { + if (a->hasPrefix) + { + writer.write(a->prefix.value); + writer.symbol("."); + } + writer.write(a->name.value); - if (a->parameters.size > 0) + if (a->parameters.size > 0 || a->hasParameterList) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -992,7 +1024,7 @@ struct Printer if (o.type) visualizeTypeAnnotation(*o.type); else - visualizeTypePackAnnotation(*o.typePack); + visualizeTypePackAnnotation(*o.typePack, false); } writer.symbol(">"); @@ -1000,7 +1032,7 @@ struct Printer } else if (const auto& a = typeAnnotation.as()) { - if (FFlag::LuauGenericFunctions && (a->generics.size > 0 || a->genericPacks.size > 0)) + if (a->generics.size > 0 || a->genericPacks.size > 0) { CommaSeparatorInserter comma(writer); writer.symbol("<"); @@ -1075,7 +1107,16 @@ struct Printer auto rta = r->as(); if (rta && rta->name == "nil") { + bool wrap = l->as() || l->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*l); + + if (wrap) + writer.symbol(")"); + writer.symbol("?"); return; } @@ -1089,7 +1130,15 @@ struct Printer writer.symbol("|"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } else if (const auto& a = typeAnnotation.as()) @@ -1102,7 +1151,15 @@ struct Printer writer.symbol("&"); } + bool wrap = a->types.data[i]->as() || a->types.data[i]->as(); + + if (wrap) + writer.symbol("("); + visualizeTypeAnnotation(*a->types.data[i]); + + if (wrap) + writer.symbol(")"); } } else if (typeAnnotation.is()) @@ -1116,31 +1173,27 @@ struct Printer } }; -void dump(AstNode* node) +std::string toString(AstNode* node) { StringWriter writer; + writer.pos = node->location.begin; + Printer printer(writer); printer.writeTypes = true; if (auto statNode = dynamic_cast(node)) - { printer.visualize(*statNode); - printf("%s\n", writer.str().c_str()); - } else if (auto exprNode = dynamic_cast(node)) - { printer.visualize(*exprNode); - printf("%s\n", writer.str().c_str()); - } else if (auto typeNode = dynamic_cast(node)) - { printer.visualizeTypeAnnotation(*typeNode); - printf("%s\n", writer.str().c_str()); - } - else - { - printf("Can't dump this node\n"); - } + + return writer.str(); +} + +void dump(AstNode* node) +{ + printf("%s\n", toString(node).c_str()); } std::string transpile(AstStatBlock& block) @@ -1149,6 +1202,7 @@ std::string transpile(AstStatBlock& block) Printer(writer).visualizeBlock(block); return writer.str(); } + std::string transpileWithTypes(AstStatBlock& block) { StringWriter writer; @@ -1158,7 +1212,7 @@ std::string transpileWithTypes(AstStatBlock& block) return writer.str(); } -TranspileResult transpile(std::string_view source, ParseOptions options) +TranspileResult transpile(std::string_view source, ParseOptions options, bool withTypes) { auto allocator = Allocator{}; auto names = AstNameTable{allocator}; @@ -1176,6 +1230,9 @@ TranspileResult transpile(std::string_view source, ParseOptions options) if (!parseResult.root) return TranspileResult{"", {}, "Internal error: Parser yielded empty parse tree"}; + if (withTypes) + return TranspileResult{transpileWithTypes(*parseResult.root)}; + return TranspileResult{transpile(*parseResult.root)}; } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 49f8e0cac..11aa7b394 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,7 +13,6 @@ #include -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAG(LuauTypeAliasPacks) static char* allocateString(Luau::Allocator& allocator, std::string_view contents) @@ -203,39 +202,23 @@ class TypeRehydrationVisitor return allocator->alloc(Location(), std::nullopt, AstName("")); AstArray generics; - if (FFlag::LuauGenericFunctions) + generics.size = ftv.generics.size(); + generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + size_t numGenerics = 0; + for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { - generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); - size_t i = 0; - for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) - { - if (auto gtv = get(*it)) - generics.data[i++] = AstName(gtv->name.c_str()); - } - } - else - { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + generics.data[numGenerics++] = AstName(gtv->name.c_str()); } AstArray genericPacks; - if (FFlag::LuauGenericFunctions) - { - genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); - size_t i = 0; - for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) - { - if (auto gtv = get(*it)) - genericPacks.data[i++] = AstName(gtv->name.c_str()); - } - } - else + genericPacks.size = ftv.genericPacks.size(); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + size_t numGenericPacks = 0; + for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { - generics.size = 0; - generics.data = nullptr; + if (auto gtv = get(*it)) + genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); } AstArray argTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 38e2e5270..8fad1af91 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -22,29 +22,19 @@ LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauGenericVariadicsUnification, false) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAG(LuauSecondTypecheckKnowsTheDataModel) LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) -LUAU_FASTFLAGVARIABLE(LuauRankNTypes, false) -LUAU_FASTFLAGVARIABLE(LuauOrPredicate, false) -LUAU_FASTFLAGVARIABLE(LuauInferReturnAssertAssign, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauAddMissingFollow, false) -LUAU_FASTFLAGVARIABLE(LuauTypeGuardPeelsAwaySubclasses, false) -LUAU_FASTFLAGVARIABLE(LuauSlightlyMoreFlexibleBinaryPredicates, false) -LUAU_FASTFLAGVARIABLE(LuauFollowInTypeFunApply, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) -LUAU_FASTFLAG(LuauNewRequireTrace) +LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau @@ -222,9 +212,9 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) , errorType(singletonTypes.errorType) - , optionalNumberType(globalTypes.addType(UnionTypeVar{{numberType, nilType}})) - , anyTypePack(globalTypes.addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}, true})) - , errorTypePack(globalTypes.addTypePack(TypePackVar{Unifiable::Error{}})) + , optionalNumberType(singletonTypes.optionalNumberType) + , anyTypePack(singletonTypes.anyTypePack) + , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -251,10 +241,8 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona if (module.cyclic) moduleScope->returnType = addTypePack(TypePack{{anyType}, std::nullopt}); - else if (FFlag::LuauRankNTypes) - moduleScope->returnType = freshTypePack(moduleScope); else - moduleScope->returnType = DEPRECATED_freshTypePack(moduleScope, true); + moduleScope->returnType = freshTypePack(moduleScope); moduleScope->varargPack = anyTypePack; @@ -268,7 +256,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona checkBlock(moduleScope, *module.root); - if (get(FFlag::LuauAddMissingFollow ? follow(moduleScope->returnType) : moduleScope->returnType)) + if (get(follow(moduleScope->returnType))) moduleScope->returnType = addTypePack(TypePack{{}, std::nullopt}); else moduleScope->returnType = anyify(moduleScope, moduleScope->returnType, Location{}); @@ -326,7 +314,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStat& program) check(scope, *typealias); else if (auto global = program.as()) { - TypeId globalType = (FFlag::LuauRankNTypes ? resolveType(scope, *global->type) : resolveType(scope, *global->type, true)); + TypeId globalType = resolveType(scope, *global->type); Name globalName(global->name.value); currentModule->declaredGlobals[globalName] = globalType; @@ -494,7 +482,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std Name name = typealias->name.value; TypeId type = bindings[name].type; - if (get(FFlag::LuauAddMissingFollow ? follow(type) : type)) + if (get(follow(type))) { *asMutable(type) = ErrorTypeVar{}; reportError(TypeError{typealias->location, OccursCheckFailed{}}); @@ -607,26 +595,22 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatRepeat& statement) void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) { std::vector> expectedTypes; + expectedTypes.reserve(return_.list.size); - if (FFlag::LuauInferReturnAssertAssign) - { - expectedTypes.reserve(return_.list.size); - - TypePackIterator expectedRetCurr = begin(scope->returnType); - TypePackIterator expectedRetEnd = end(scope->returnType); + TypePackIterator expectedRetCurr = begin(scope->returnType); + TypePackIterator expectedRetEnd = end(scope->returnType); - for (size_t i = 0; i < return_.list.size; ++i) + for (size_t i = 0; i < return_.list.size; ++i) + { + if (expectedRetCurr != expectedRetEnd) { - if (expectedRetCurr != expectedRetEnd) - { - expectedTypes.push_back(*expectedRetCurr); - ++expectedRetCurr; - } - else if (auto expectedArgsTail = expectedRetCurr.tail()) - { - if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) - expectedTypes.push_back(vtp->ty); - } + expectedTypes.push_back(*expectedRetCurr); + ++expectedRetCurr; + } + else if (auto expectedArgsTail = expectedRetCurr.tail()) + { + if (const VariadicTypePack* vtp = get(follow(*expectedArgsTail))) + expectedTypes.push_back(vtp->ty); } } @@ -672,34 +656,30 @@ ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { std::vector> expectedTypes; + expectedTypes.reserve(assign.vars.size); - if (FFlag::LuauInferReturnAssertAssign) - { - expectedTypes.reserve(assign.vars.size); + ScopePtr moduleScope = currentModule->getModuleScope(); - ScopePtr moduleScope = currentModule->getModuleScope(); + for (size_t i = 0; i < assign.vars.size; ++i) + { + AstExpr* dest = assign.vars.data[i]; - for (size_t i = 0; i < assign.vars.size; ++i) + if (auto a = dest->as()) { - AstExpr* dest = assign.vars.data[i]; - - if (auto a = dest->as()) - { - // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later - expectedTypes.push_back(scope->lookup(a->local)); - } - else if (auto a = dest->as()) - { - // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList - if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) - expectedTypes.push_back(it->second.typeId); - else - expectedTypes.push_back(std::nullopt); - } + // AstExprLocal l-values will have to be checked again because their type might have been mutated during checkExprList later + expectedTypes.push_back(scope->lookup(a->local)); + } + else if (auto a = dest->as()) + { + // AstExprGlobal l-values lookup is inlined here to avoid creating a global binding before checkExprList + if (auto it = moduleScope->bindings.find(a->name); it != moduleScope->bindings.end()) + expectedTypes.push_back(it->second.typeId); else - { - expectedTypes.push_back(checkLValue(scope, *dest)); - } + expectedTypes.push_back(std::nullopt); + } + else + { + expectedTypes.push_back(checkLValue(scope, *dest)); } } @@ -715,7 +695,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) AstExpr* dest = assign.vars.data[i]; TypeId left = nullptr; - if (!FFlag::LuauInferReturnAssertAssign || dest->is() || dest->is()) + if (dest->is() || dest->is()) left = checkLValue(scope, *dest); else left = *expectedTypes[i]; @@ -751,11 +731,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (right) { - if (FFlag::LuauGenericFunctions && !maybeGeneric(left) && isGeneric(right)) - right = instantiate(scope, right, loc); - - if (!FFlag::LuauGenericFunctions && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && - get(FFlag::LuauAddMissingFollow ? follow(right) : right)) + if (!maybeGeneric(left) && isGeneric(right)) right = instantiate(scope, right, loc); // Setting a table entry to nil doesn't mean nil is the type of the indexer, it is just deleting the entry @@ -766,7 +742,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) if (!destTableTypeReceivingNil || !destTableTypeReceivingNil->indexer) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. - if (isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(left) : left) && !get(follow(right))) + if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) unify(left, anyType, loc); else unify(left, right, loc); @@ -815,7 +791,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) if (annotation) { - ty = (FFlag::LuauRankNTypes ? resolveType(scope, *annotation) : resolveType(scope, *annotation, true)); + ty = resolveType(scope, *annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(ty))) @@ -823,23 +799,19 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) } if (!ty) - ty = rhsIsTable ? (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)) - : isNonstrictMode() ? anyType : (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + ty = rhsIsTable ? freshType(scope) : isNonstrictMode() ? anyType : freshType(scope); varBindings.emplace_back(vars[i], Binding{ty, vars[i]->location}); variableTypes.push_back(ty); expectedTypes.push_back(ty); - if (FFlag::LuauGenericFunctions) - instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); - else - instantiateGenerics.push_back(annotation != nullptr && get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)); + instantiateGenerics.push_back(annotation != nullptr && !maybeGeneric(ty)); } if (local.values.size > 0) { - TypePackId variablePack = addTypePack(variableTypes, FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, true)); + TypePackId variablePack = addTypePack(variableTypes, freshTypePack(scope)); TypePackId valuePack = checkExprList(scope, local.location, local.values, /* substituteFreeForNil= */ true, instantiateGenerics, expectedTypes).type; @@ -979,8 +951,6 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) { AstExprCall* exprCall = firstValue->as(); callRetPack = checkExprPack(scope, *exprCall).type; - if (!FFlag::LuauRankNTypes) - callRetPack = DEPRECATED_instantiate(scope, callRetPack, exprCall->location); callRetPack = follow(callRetPack); if (get(callRetPack)) @@ -998,8 +968,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else { iterTy = *first(callRetPack); - if (FFlag::LuauRankNTypes) - iterTy = instantiate(scope, iterTy, exprCall->location); + iterTy = instantiate(scope, iterTy, exprCall->location); } } else @@ -1158,10 +1127,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (FFlag::LuauGenericFunctions) - scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; - else - scope->bindings[function.name] = {quantify(scope, ty, function.name->location), function.name->location}; + scope->bindings[function.name] = {quantify(funScope, ty, function.name->location), function.name->location}; } void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias, int subLevel, bool forwardDeclare) @@ -1199,7 +1165,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; @@ -1234,7 +1200,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; } - TypeId ty = (FFlag::LuauRankNTypes ? freshType(aliasScope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; @@ -1266,7 +1232,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } } - TypeId ty = (FFlag::LuauRankNTypes ? resolveType(aliasScope, *typealias.type) : resolveType(aliasScope, *typealias.type, true)); + TypeId ty = resolveType(aliasScope, *typealias.type); if (auto ttv = getMutable(follow(ty))) { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy @@ -1325,25 +1291,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); superTy = lookupType->type; - if (FFlag::LuauAddMissingFollow) + if (!get(follow(*superTy))) { - if (!get(follow(*superTy))) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); + reportError(declaredClass.location, + GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", superName.c_str(), declaredClass.name.value)}); - return; - } - } - else - { - if (const ClassTypeVar* superCtv = get(*superTy); !superCtv) - { - reportError(declaredClass.location, GenericError{format("Cannot use non-class type '%s' as a superclass of class '%s'", - superName.c_str(), declaredClass.name.value)}); - - return; - } + return; } } @@ -1558,8 +1511,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa } else if (auto ftp = get(varargPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); - TypePackId tail = (FFlag::LuauRankNTypes ? freshTypePack(scope) : DEPRECATED_freshTypePack(scope, ftp->DEPRECATED_canBeGeneric)); + TypeId head = freshType(scope); + TypePackId tail = freshTypePack(scope); *asMutable(varargPack) = TypePack{{head}, tail}; return {head}; } @@ -1567,7 +1520,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {errorType}; else if (auto vtp = get(varargPack)) return {vtp->ty}; - else if (FFlag::LuauGenericVariadicsUnification && get(varargPack)) + else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); @@ -1588,7 +1541,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa } else if (auto ftp = get(retPack)) { - TypeId head = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, ftp->DEPRECATED_canBeGeneric)); + TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); unify(retPack, pack, expr.location); return {head, std::move(result.predicates)}; @@ -1667,7 +1620,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId result = freshType(scope); tableType->props[name] = {result}; return result; } @@ -2129,7 +2082,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!isNonstrictMode() && !isOrOp) return ty; - if (auto i = get(ty)) + if (get(ty)) { std::optional cleaned = tryStripUnionFromNil(ty); @@ -2158,16 +2111,9 @@ TypeId TypeChecker::checkRelationalOperation( { if (expr.op == AstExprBinary::Or && subexp->op == AstExprBinary::And) { - if (FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - { - ScopePtr subScope = childScope(scope, subexp->location); - reportErrors(resolve(predicates, subScope, true)); - return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); - } - else - { - return unionOfTypes(rhsType, checkExpr(scope, *subexp->right).type, expr.location); - } + ScopePtr subScope = childScope(scope, subexp->location); + reportErrors(resolve(predicates, subScope, true)); + return unionOfTypes(rhsType, stripNil(checkExpr(subScope, *subexp->right).type, true), expr.location); } } @@ -2217,10 +2163,8 @@ TypeId TypeChecker::checkRelationalOperation( std::string metamethodName = opToMetaTableEntry(expr.op); - std::optional leftMetatable = - isString(lhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType); - std::optional rightMetatable = - isString(rhsType) ? std::nullopt : getMetatable(FFlag::LuauAddMissingFollow ? follow(rhsType) : rhsType); + std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); + std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { @@ -2266,7 +2210,7 @@ TypeId TypeChecker::checkRelationalOperation( } } - if (get(FFlag::LuauAddMissingFollow ? follow(lhsType) : lhsType) && !isEquality) + if (get(follow(lhsType)) && !isEquality) { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); @@ -2417,12 +2361,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi resolve(lhs.predicates, innerScope, true); ExprResult rhs = checkExpr(innerScope, *expr.right); - if (!FFlag::LuauSlightlyMoreFlexibleBinaryPredicates) - resolve(rhs.predicates, innerScope, true); return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } - else if (FFlag::LuauOrPredicate && expr.op == AstExprBinary::Or) + else if (expr.op == AstExprBinary::Or) { ExprResult lhs = checkExpr(scope, *expr.left); @@ -2468,19 +2410,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr) { - ExprResult result; - TypeId annotationType; - - if (FFlag::LuauInferReturnAssertAssign) - { - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - result = checkExpr(scope, *expr.expr, annotationType); - } - else - { - result = checkExpr(scope, *expr.expr); - annotationType = (FFlag::LuauRankNTypes ? resolveType(scope, *expr.annotation) : resolveType(scope, *expr.annotation, true)); - } + TypeId annotationType = resolveType(scope, *expr.annotation); + ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); if (!errorVec.empty()) @@ -2570,23 +2501,16 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (it != moduleScope->bindings.end()) return std::pair(it->second.typeId, &it->second.typeId); - if (isNonstrictMode() || FFlag::LuauSecondTypecheckKnowsTheDataModel) - { - TypeId result = (FFlag::LuauGenericFunctions && FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(moduleScope, true)); - - Binding& binding = moduleScope->bindings[expr.name]; - binding = {result, expr.location}; + TypeId result = freshType(scope); + Binding& binding = moduleScope->bindings[expr.name]; + binding = {result, expr.location}; - // If we're in strict mode, we want to report defining a global as an error, - // but still add it to the bindings, so that autocomplete includes it in completions. - if (!isNonstrictMode()) - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); + // If we're in strict mode, we want to report defining a global as an error, + // but still add it to the bindings, so that autocomplete includes it in completions. + if (!isNonstrictMode()) + reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(result, &binding.typeId); - } - - reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(errorType, nullptr); + return std::pair(result, &binding.typeId); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) @@ -2611,7 +2535,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { - TypeId theType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId theType = freshType(scope); Property& property = lhsTable->props[name]; property.type = theType; property.location = expr.indexLocation; @@ -2683,7 +2607,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope AstExprConstantString* value = expr.index->as(); - if (value && FFlag::LuauClassPropertyAccessAsString) + if (value) { if (const ClassTypeVar* exprClass = get(exprType)) { @@ -2714,7 +2638,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(scope); Property& property = exprTable->props[value->value.data]; property.type = resultType; property.location = expr.index->location; @@ -2730,7 +2654,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId resultType = freshType(scope); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return std::pair(resultType, nullptr); } @@ -2758,7 +2682,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) } else { - TypeId ty = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + TypeId ty = freshType(scope); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2768,7 +2692,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)), funName.location}; + binding = {freshType(scope), funName.location}; return binding.typeId; } @@ -2798,7 +2722,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Property& property = ttv->props[name]; - property.type = (FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + property.type = freshType(scope); property.location = indexName->indexLocation; ttv->methodDefinitionLocations[name] = funName.location; return property.type; @@ -2865,22 +2789,11 @@ std::pair TypeChecker::checkFunctionSignature( expectedFunctionType = nullptr; } - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); - } + auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); TypePackId retPack; if (expr.hasReturnAnnotation) - { - if (FFlag::LuauGenericFunctions) - retPack = resolveTypePack(funScope, expr.returnAnnotation); - else - retPack = resolveTypePack(scope, expr.returnAnnotation); - } + retPack = resolveTypePack(funScope, expr.returnAnnotation); else if (isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) @@ -2889,24 +2802,17 @@ std::pair TypeChecker::checkFunctionSignature( // Do not infer 'nil' as function return type if (!tail && head.size() == 1 && isNil(head[0])) - retPack = FFlag::LuauGenericFunctions ? freshTypePack(funScope) : freshTypePack(scope); + retPack = freshTypePack(funScope); else retPack = addTypePack(head, tail); } - else if (FFlag::LuauGenericFunctions) - retPack = freshTypePack(funScope); else - retPack = freshTypePack(scope); + retPack = freshTypePack(funScope); if (expr.vararg) { if (expr.varargAnnotation) - { - if (FFlag::LuauGenericFunctions) - funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); - else - funScope->varargPack = resolveTypePack(scope, *expr.varargAnnotation); - } + funScope->varargPack = resolveTypePack(funScope, *expr.varargAnnotation); else { if (expectedFunctionType && !isNonstrictMode()) @@ -2963,7 +2869,7 @@ std::pair TypeChecker::checkFunctionSignature( if (local->annotation) { - argType = resolveType((FFlag::LuauGenericFunctions ? funScope : scope), *local->annotation); + argType = resolveType(funScope, *local->annotation); // If the annotation type has an error, treat it as if there was no annotation if (get(follow(argType))) @@ -3022,7 +2928,7 @@ static bool allowsNoReturnValues(const TypePackId tp) { for (TypeId ty : tp) { - if (!get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + if (!get(follow(ty))) { return false; } @@ -3058,7 +2964,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE check(scope, *function.body); // We explicitly don't follow here to check if we have a 'true' free type instead of bound one - if (FFlag::LuauAddMissingFollow ? get_if(&funTy->retType->ty) : get(funTy->retType)) + if (get_if(&funTy->retType->ty)) *asMutable(funTy->retType) = TypePack{{}, std::nullopt}; bool reachesImplicitReturn = getFallthrough(function.body) != nullptr; @@ -3287,7 +3193,7 @@ void TypeChecker::checkArgumentList( return; } - else if (FFlag::LuauGenericVariadicsUnification && get(tail)) + else if (get(tail)) { // Create a type pack out of the remaining argument types // and unify it with the tail. @@ -3310,7 +3216,7 @@ void TypeChecker::checkArgumentList( return; } - else if (FFlag::LuauRankNTypes && get(tail)) + else if (get(tail)) { // For this case, we want the error span to cover every errant extra parameter Location location = state.location; @@ -3323,10 +3229,7 @@ void TypeChecker::checkArgumentList( } else { - if (FFlag::LuauRankNTypes) - unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); - else - state.tryUnify(*paramIter, *argIter, /*isFunctionCall*/ false); + unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); ++argIter; ++paramIter; } @@ -3356,9 +3259,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A ice("method call expression has no 'self'"); selfType = checkExpr(scope, *indexExpr->expr).type; - if (!FFlag::LuauRankNTypes) - instantiate(scope, selfType, expr.func->location); - selfType = stripFromNilAndReport(selfType, expr.func->location); if (std::optional propTy = getIndexTypeFromType(scope, selfType, indexExpr->index.value, expr.location, true)) @@ -3393,8 +3293,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); - TypePackId argList = argListResult.type; - TypePackId argPack = (FFlag::LuauRankNTypes ? argList : DEPRECATED_instantiate(scope, argList, expr.location)); + TypePackId argPack = argListResult.type; if (get(argPack)) return ExprResult{errorTypePack}; @@ -3526,8 +3425,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); TypeId fn = *ty; - if (FFlag::LuauRankNTypes) - fn = instantiate(scope, fn, expr.func->location); + fn = instantiate(scope, fn, expr.func->location); return checkCallOverload( scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); @@ -3800,7 +3698,7 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L TypeId actualType = substituteFreeForNil && expr->is() ? freshType(scope) : type; - if (instantiateGenerics.size() > i && instantiateGenerics[i] && (FFlag::LuauGenericFunctions || get(actualType))) + if (instantiateGenerics.size() > i && instantiateGenerics[i]) actualType = instantiate(scope, actualType, expr->location); if (expectedType) @@ -3837,7 +3735,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); - if (FFlag::LuauNewRequireTrace && moduleInfo.name.empty()) + if (FFlag::LuauNewRequireTrace2 && moduleInfo.name.empty()) { if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { @@ -3922,7 +3820,6 @@ bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& locat bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) { - LUAU_ASSERT(FFlag::LuauRankNTypes); Unifier state = mkUnifier(location); unifyWithInstantiationIfNeeded(scope, left, right, state); @@ -3933,7 +3830,6 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId l void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (!maybeGeneric(right)) // Quick check to see if we definitely can't instantiate state.tryUnify(left, right, /*isFunctionCall*/ false); @@ -3973,19 +3869,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId l bool Instantiation::isDirty(TypeId ty) { - if (FFlag::LuauRankNTypes) - { - if (get(ty)) - return true; - else - return false; - } - - if (const FunctionTypeVar* ftv = get(ty)) - return !ftv->generics.empty() || !ftv->genericPacks.empty(); - else if (const TableTypeVar* ttv = get(ty)) - return ttv->state == TableState::Generic; - else if (get(ty)) + if (get(ty)) return true; else return false; @@ -3993,18 +3877,11 @@ bool Instantiation::isDirty(TypeId ty) bool Instantiation::isDirty(TypePackId tp) { - if (FFlag::LuauRankNTypes) - return false; - - if (get(tp)) - return true; - else - return false; + return false; } bool Instantiation::ignoreChildren(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (get(ty)) return true; else @@ -4013,63 +3890,38 @@ bool Instantiation::ignoreChildren(TypeId ty) TypeId Instantiation::clean(TypeId ty) { - LUAU_ASSERT(isDirty(ty)); - - if (const FunctionTypeVar* ftv = get(ty)) - { - FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; - clone.magicFunction = ftv->magicFunction; - clone.tags = ftv->tags; - clone.argNames = ftv->argNames; - TypeId result = addType(std::move(clone)); - - if (FFlag::LuauRankNTypes) - { - // Annoyingly, we have to do this even if there are no generics, - // to replace any generic tables. - replaceGenerics.level = level; - replaceGenerics.currentModule = currentModule; - replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); - replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); - - // TODO: What to do if this returns nullopt? - // We don't have access to the error-reporting machinery - result = replaceGenerics.substitute(result).value_or(result); - } - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else if (const TableTypeVar* ttv = get(ty)) - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - TypeId result = addType(std::move(clone)); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } - else - { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - TypeId result = addType(FreeTypeVar{level}); - - asMutable(result)->documentationSymbol = ty->documentationSymbol; - return result; - } + const FunctionTypeVar* ftv = get(ty); + LUAU_ASSERT(ftv); + + FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; + clone.magicFunction = ftv->magicFunction; + clone.tags = ftv->tags; + clone.argNames = ftv->argNames; + TypeId result = addType(std::move(clone)); + + // Annoyingly, we have to do this even if there are no generics, + // to replace any generic tables. + replaceGenerics.level = level; + replaceGenerics.currentModule = currentModule; + replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); + replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); + + // TODO: What to do if this returns nullopt? + // We don't have access to the error-reporting machinery + result = replaceGenerics.substitute(result).value_or(result); + + asMutable(result)->documentationSymbol = ty->documentationSymbol; + return result; } TypePackId Instantiation::clean(TypePackId tp) { - LUAU_ASSERT(!FFlag::LuauRankNTypes); - return addTypePack(TypePackVar(FreeTypePack{level})); + LUAU_ASSERT(false); + return tp; } bool ReplaceGenerics::ignoreChildren(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (const FunctionTypeVar* ftv = get(ty)) // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. @@ -4083,7 +3935,6 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) bool ReplaceGenerics::isDirty(TypeId ty) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (const TableTypeVar* ttv = get(ty)) return ttv->state == TableState::Generic; else if (get(ty)) @@ -4094,7 +3945,6 @@ bool ReplaceGenerics::isDirty(TypeId ty) bool ReplaceGenerics::isDirty(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauRankNTypes); if (get(tp)) return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); else @@ -4255,21 +4105,6 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat } } -TypePackId TypeChecker::DEPRECATED_instantiate(const ScopePtr& scope, TypePackId ty, Location location) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); - instantiation.level = scope->level; - instantiation.currentModule = currentModule; - std::optional instantiated = instantiation.substitute(ty); - if (instantiated.has_value()) - return *instantiated; - else - { - reportError(location, UnificationTooComplex{}); - return errorTypePack; - } -} - TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { anyification.anyType = anyType; @@ -4444,16 +4279,6 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } -TypeId TypeChecker::DEPRECATED_freshType(const ScopePtr& scope, bool canBeGeneric) -{ - return DEPRECATED_freshType(scope->level, canBeGeneric); -} - -TypeId TypeChecker::DEPRECATED_freshType(TypeLevel level, bool canBeGeneric) -{ - return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level, canBeGeneric))); -} - std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4509,21 +4334,8 @@ TypePackId TypeChecker::freshTypePack(TypeLevel level) return addTypePack(TypePackVar(FreeTypePack(level))); } -TypePackId TypeChecker::DEPRECATED_freshTypePack(const ScopePtr& scope, bool canBeGeneric) -{ - return DEPRECATED_freshTypePack(scope->level, canBeGeneric); -} - -TypePackId TypeChecker::DEPRECATED_freshTypePack(TypeLevel level, bool canBeGeneric) -{ - return addTypePack(TypePackVar(FreeTypePack(level, canBeGeneric))); -} - -TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation, bool DEPRECATED_canBeGeneric) +TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation) { - if (DEPRECATED_canBeGeneric) - LUAU_ASSERT(!FFlag::LuauRankNTypes); - if (const auto& lit = annotation.as()) { std::optional tf; @@ -4668,11 +4480,11 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation std::optional tableIndexer; for (const auto& prop : table->props) - props[prop.name.value] = {resolveType(scope, *prop.type, DEPRECATED_canBeGeneric)}; + props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType, DEPRECATED_canBeGeneric), resolveType(scope, *indexer->resultType, DEPRECATED_canBeGeneric)); + resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4683,17 +4495,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { ScopePtr funcScope = childScope(scope, func->location); - std::vector generics; - std::vector genericPacks; - - if (FFlag::LuauGenericFunctions) - { - std::tie(generics, genericPacks) = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); - } - - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && (generics.size() > 0 || genericPacks.size() > 0)) - reportError(TypeError{annotation.location, GenericError{"generic function where only monotypes are allowed"}}); + auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); @@ -4716,16 +4518,13 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (auto typeOf = annotation.as()) { TypeId ty = checkExpr(scope, *typeOf->expr).type; - // TODO: better error message CLI-39912 - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !DEPRECATED_canBeGeneric && isGeneric(ty)) - reportError(TypeError{annotation.location, GenericError{"typeof produced a polytype where only monotypes are allowed"}}); return ty; } else if (const auto& un = annotation.as()) { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(UnionTypeVar{types}); } @@ -4733,7 +4532,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { std::vector types; for (AstType* ann : un->types) - types.push_back(resolveType(scope, *ann, DEPRECATED_canBeGeneric)); + types.push_back(resolveType(scope, *ann)); return addType(IntersectionTypeVar{types}); } @@ -4919,9 +4718,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) { - // TODO: CLI-46926 it's a bad idea to rename the type whether we follow through the BoundTypeVar or not - TypeId target = FFlag::LuauFollowInTypeFunApply ? follow(instantiated) : instantiated; - + // TODO: CLI-46926 it's not a good idea to rename the type here + TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; TableTypeVar* ttv = getMutableTableType(target); @@ -5152,31 +4950,18 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { - if (FFlag::LuauOrPredicate) + if (!sense) { - if (!sense) - { - OrPredicate orP{ - {NotPredicate{std::move(andP.lhs)}}, - {NotPredicate{std::move(andP.rhs)}}, - }; - - return resolve(orP, errVec, refis, scope, !sense); - } + OrPredicate orP{ + {NotPredicate{std::move(andP.lhs)}}, + {NotPredicate{std::move(andP.rhs)}}, + }; - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); + return resolve(orP, errVec, refis, scope, !sense); } - else - { - // And predicate is currently not resolvable when sense is false. 'not (a and b)' is synonymous with '(not a) or (not b)'. - // TODO: implement environment merging to permit this case. - if (!sense) - return; - resolve(andP.lhs, errVec, refis, scope, sense); - resolve(andP.rhs, errVec, refis, scope, sense); - } + resolve(andP.lhs, errVec, refis, scope, sense); + resolve(andP.rhs, errVec, refis, scope, sense); } void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5207,58 +4992,41 @@ void TypeChecker::resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMa void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { auto predicate = [&](TypeId option) -> std::optional { - if (FFlag::LuauTypeGuardPeelsAwaySubclasses) - { - // This by itself is not truly enough to determine that A is stronger than B or vice versa. - // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. - // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) - bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); - bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); - - // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. - if (!optionIsSubtype && targetIsSubtype) + // This by itself is not truly enough to determine that A is stronger than B or vice versa. + // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. + // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) + bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); + bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + + // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. + if (!optionIsSubtype && targetIsSubtype) + return sense ? isaP.ty : option; + + // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. + if (optionIsSubtype && !targetIsSubtype) + return sense ? std::optional(option) : std::nullopt; + + // If neither has any relationship, we only return A if sense is false. + if (!optionIsSubtype && !targetIsSubtype) + return sense ? std::nullopt : std::optional(option); + + // If both are subtypes, then we're in one of the two situations: + // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ + // 2. any <: Instance ∧ Instance <: any + // Right now, we have to look at the types to see if they were undecidables. + // By this point, we also know free tables are also subtypes and supertypes. + if (optionIsSubtype && targetIsSubtype) + { + // We can only have (any, Instance) because the rhs is never undecidable right now. + // So we can just return the right hand side immediately. + + // typeof(x) == "Instance" where x : any + auto ttv = get(option); + if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) return sense ? isaP.ty : option; - // If A is a subset of B, then if sense is true we pick A, otherwise we eliminate A. - if (optionIsSubtype && !targetIsSubtype) - return sense ? std::optional(option) : std::nullopt; - - // If neither has any relationship, we only return A if sense is false. - if (!optionIsSubtype && !targetIsSubtype) - return sense ? std::nullopt : std::optional(option); - - // If both are subtypes, then we're in one of the two situations: - // 1. Instance₁ <: Instance₂ ∧ Instance₂ <: Instance₁ - // 2. any <: Instance ∧ Instance <: any - // Right now, we have to look at the types to see if they were undecidables. - // By this point, we also know free tables are also subtypes and supertypes. - if (optionIsSubtype && targetIsSubtype) - { - // We can only have (any, Instance) because the rhs is never undecidable right now. - // So we can just return the right hand side immediately. - - // typeof(x) == "Instance" where x : any - auto ttv = get(option); - if (isUndecidable(option) || (ttv && ttv->state == TableState::Free)) - return sense ? isaP.ty : option; - - // typeof(x) == "Instance" where x : Instance - if (sense) - return isaP.ty; - } - } - else - { - auto lctv = get(option); - auto rctv = get(isaP.ty); - - if (isSubclass(lctv, rctv) == sense) - return option; - - if (isSubclass(rctv, lctv) == sense) - return isaP.ty; - - if (canUnify(option, isaP.ty, isaP.location).empty() == sense) + // typeof(x) == "Instance" where x : Instance + if (sense) return isaP.ty; } @@ -5457,7 +5225,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp TypePack* expectedPack = getMutable(expectedTypePack); LUAU_ASSERT(expectedPack); for (size_t i = 0; i < expectedLength; ++i) - expectedPack->head.push_back(FFlag::LuauRankNTypes ? freshType(scope) : DEPRECATED_freshType(scope, true)); + expectedPack->head.push_back(freshType(scope)); unify(expectedTypePack, tp, location); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 68a16ef04..228b19267 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -97,7 +97,7 @@ TypePackIterator begin(TypePackId tp) TypePackIterator end(TypePackId tp) { - return FFlag::LuauAddMissingFollow ? TypePackIterator{} : TypePackIterator{nullptr}; + return TypePackIterator{}; } bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) @@ -203,7 +203,7 @@ TypePackId follow(TypePackId tp) size_t size(TypePackId tp) { - if (auto pack = get(FFlag::LuauAddMissingFollow ? follow(tp) : tp)) + if (auto pack = get(follow(tp))) return size(*pack); else return 0; @@ -216,7 +216,7 @@ bool finite(TypePackId tp) if (auto pack = get(tp)) return pack->tail ? finite(*pack->tail) : true; - if (auto pack = get(tp)) + if (get(tp)) return false; return true; @@ -227,7 +227,7 @@ size_t size(const TypePack& tp) size_t result = tp.head.size(); if (tp.tail) { - const TypePack* tail = get(FFlag::LuauAddMissingFollow ? follow(*tp.tail) : *tp.tail); + const TypePack* tail = get(follow(*tp.tail)); if (tail) result += size(*tail); } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index e82f7519d..cd447ca23 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,8 +19,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauRankNTypes) -LUAU_FASTFLAG(LuauTypeGuardPeelsAwaySubclasses) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) @@ -42,7 +40,7 @@ TypeId follow(TypeId t) }; auto force = [](TypeId ty) { - if (auto ltv = FFlag::LuauAddMissingFollow ? get_if(&ty->ty) : get(ty)) + if (auto ltv = get_if(&ty->ty)) { TypeId res = ltv->thunk(); if (get(res)) @@ -296,7 +294,7 @@ bool maybeGeneric(TypeId ty) { ty = follow(ty); if (auto ftv = get(ty)) - return FFlag::LuauRankNTypes || ftv->DEPRECATED_canBeGeneric; + return true; else if (auto ttv = get(ty)) { // TODO: recurse on table types CLI-39914 @@ -545,15 +543,30 @@ TypeId makeFunction(TypeArena& arena, std::optional selfType, std::initi std::initializer_list genericPacks, std::initializer_list paramTypes, std::initializer_list paramNames, std::initializer_list retTypes); +static TypeVar nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true}; +static TypeVar numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true}; +static TypeVar stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true}; +static TypeVar booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true}; +static TypeVar threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true}; +static TypeVar anyType_{AnyTypeVar{}}; +static TypeVar errorType_{ErrorTypeVar{}}; +static TypeVar optionalNumberType_{UnionTypeVar{{&numberType_, &nilType_}}}; + +static TypePackVar anyTypePack_{VariadicTypePack{&anyType_}, true}; +static TypePackVar errorTypePack_{Unifiable::Error{}}; + SingletonTypes::SingletonTypes() - : arena(new TypeArena) - , nilType_{PrimitiveTypeVar{PrimitiveTypeVar::NilType}, /*persistent*/ true} - , numberType_{PrimitiveTypeVar{PrimitiveTypeVar::Number}, /*persistent*/ true} - , stringType_{PrimitiveTypeVar{PrimitiveTypeVar::String}, /*persistent*/ true} - , booleanType_{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}, /*persistent*/ true} - , threadType_{PrimitiveTypeVar{PrimitiveTypeVar::Thread}, /*persistent*/ true} - , anyType_{AnyTypeVar{}} - , errorType_{ErrorTypeVar{}} + : nilType(&nilType_) + , numberType(&numberType_) + , stringType(&stringType_) + , booleanType(&booleanType_) + , threadType(&threadType_) + , anyType(&anyType_) + , errorType(&errorType_) + , optionalNumberType(&optionalNumberType_) + , anyTypePack(&anyTypePack_) + , errorTypePack(&errorTypePack_) + , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; @@ -749,9 +762,9 @@ void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) { - if (const PrimitiveTypeVar* ptv = get(ty)) + if (get(ty)) formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); - else if (const AnyTypeVar* atv = get(ty)) + else if (get(ty)) formatAppend(result, "n%d [label=\"any\"];\n", index); } else @@ -902,19 +915,19 @@ void StateDot::visitChildren(TypeId ty, int index) finishNodeLabel(ty); finishNode(); } - else if (const AnyTypeVar* atv = get(ty)) + else if (get(ty)) { formatAppend(result, "AnyTypeVar %d", index); finishNodeLabel(ty); finishNode(); } - else if (const PrimitiveTypeVar* ptv = get(ty)) + else if (get(ty)) { formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); finishNodeLabel(ty); finishNode(); } - else if (const ErrorTypeVar* etv = get(ty)) + else if (get(ty)) { formatAppend(result, "ErrorTypeVar %d", index); finishNodeLabel(ty); @@ -994,7 +1007,7 @@ void StateDot::visitChildren(TypePackId tp, int index) finishNodeLabel(tp); finishNode(); } - else if (const Unifiable::Error* etp = get(tp)) + else if (get(tp)) { formatAppend(result, "ErrorTypePack %d", index); finishNodeLabel(tp); @@ -1372,24 +1385,6 @@ UnionTypeVarIterator end(const UnionTypeVar* utv) return UnionTypeVarIterator{}; } -static std::vector DEPRECATED_filterMap(TypeId type, TypeIdPredicate predicate) -{ - std::vector result; - - if (auto utv = get(follow(type))) - { - for (TypeId option : utv) - { - if (auto out = predicate(follow(option))) - result.push_back(*out); - } - } - else if (auto out = predicate(follow(type))) - return {*out}; - - return result; -} - static std::vector parseFormatString(TypeChecker& typechecker, const char* data, size_t size) { const char* options = "cdiouxXeEfgGqs"; @@ -1470,9 +1465,6 @@ std::optional> magicFunctionFormat( std::vector filterMap(TypeId type, TypeIdPredicate predicate) { - if (!FFlag::LuauTypeGuardPeelsAwaySubclasses) - return DEPRECATED_filterMap(type, predicate); - type = follow(type); if (auto utv = get(type)) diff --git a/Analysis/src/Unifiable.cpp b/Analysis/src/Unifiable.cpp index cef07833a..dc5546640 100644 --- a/Analysis/src/Unifiable.cpp +++ b/Analysis/src/Unifiable.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Unifiable.h" -LUAU_FASTFLAG(LuauRankNTypes) - namespace Luau { namespace Unifiable @@ -14,14 +12,6 @@ Free::Free(TypeLevel level) { } -Free::Free(TypeLevel level, bool DEPRECATED_canBeGeneric) - : index(++nextIndex) - , level(level) - , DEPRECATED_canBeGeneric(DEPRECATED_canBeGeneric) -{ - LUAU_ASSERT(!FFlag::LuauRankNTypes); -} - int Free::nextIndex = 0; Generic::Generic() diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 2539650a4..82f621b66 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,17 +14,15 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAG(LuauGenericFunctions) LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); -LUAU_FASTFLAGVARIABLE(LuauDontMutatePersistentFunctions, false) -LUAU_FASTFLAG(LuauRankNTypes) LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) -LUAU_FASTFLAGVARIABLE(LuauSealedTableUnifyOptionalFix, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) namespace Luau { @@ -219,17 +217,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool *asMutable(subTy) = BoundTypeVar(superTy); } - if (!FFlag::LuauRankNTypes) - l->DEPRECATED_canBeGeneric &= r->DEPRECATED_canBeGeneric; - return; } - else if (l && r && FFlag::LuauGenericFunctions) + else if (l && r) { log(superTy); occursCheck(superTy, subTy); - if (!FFlag::LuauRankNTypes) - r->DEPRECATED_canBeGeneric &= l->DEPRECATED_canBeGeneric; r->level = min(r->level, l->level); *asMutable(superTy) = BoundTypeVar(subTy); return; @@ -240,7 +233,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto rightGeneric = get(subTy); - if (FFlag::LuauRankNTypes && rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (rightGeneric && !rightGeneric->level.subsumes(l->level)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -266,31 +259,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Unification can't change the level of a generic. auto leftGeneric = get(superTy); - if (FFlag::LuauRankNTypes && leftGeneric && !leftGeneric->level.subsumes(r->level)) - { - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); - return; - } - - // This is the old code which is just wrong - auto wrongGeneric = get(subTy); // Guaranteed to be null - if (!FFlag::LuauRankNTypes && FFlag::LuauGenericFunctions && wrongGeneric && r->level.subsumes(wrongGeneric->level)) + if (leftGeneric && !leftGeneric->level.subsumes(r->level)) { - // This code is unreachable! Should we just remove it? // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } - // Check if we're unifying a monotype with a polytype - if (FFlag::LuauGenericFunctions && !FFlag::LuauRankNTypes && !r->DEPRECATED_canBeGeneric && isGeneric(superTy)) - { - // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Failed to unify a polytype with a monotype"}}); - return; - } - if (!get(subTy)) { if (auto leftLevel = getMutableLevel(superTy)) @@ -333,6 +308,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // A | B <: T if A <: T and B <: T bool failed = false; std::optional unificationTooComplex; + std::optional firstFailedOption; size_t count = uv->options.size(); size_t i = 0; @@ -345,7 +321,13 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; else if (!innerState.errors.empty()) + { + // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' + if (FFlag::LuauExtendedTypeMismatchError && !firstFailedOption && !isNil(type)) + firstFailedOption = {innerState.errors.front()}; + failed = true; + } if (i != count - 1) innerState.log.rollback(); @@ -358,7 +340,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (failed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible", *firstFailedOption}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (const UnionTypeVar* uv = get(superTy)) { @@ -425,14 +412,49 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (const IntersectionTypeVar* uv = get(superTy)) { - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) + if (FFlag::LuauExtendedTypeMismatchError) + { + std::optional unificationTooComplex; + std::optional firstFailedOption; + + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; + } + + log.concat(std::move(innerState.log)); + } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible", *firstFailedOption}}); + } + else { - tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + } } } else if (const IntersectionTypeVar* uv = get(subTy)) @@ -480,7 +502,12 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (!found) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + { + if (FFlag::LuauExtendedTypeMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + } } else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); @@ -773,8 +800,8 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - const bool lFreeTail = l->tail && get(FFlag::LuauAddMissingFollow ? follow(*l->tail) : *l->tail) != nullptr; - const bool rFreeTail = r->tail && get(FFlag::LuauAddMissingFollow ? follow(*r->tail) : *r->tail) != nullptr; + const bool lFreeTail = l->tail && get(follow(*l->tail)) != nullptr; + const bool rFreeTail = r->tail && get(follow(*r->tail)) != nullptr; if (lFreeTail && rFreeTail) tryUnify_(*l->tail, *r->tail); else if (lFreeTail) @@ -812,7 +839,7 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal } // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(FFlag::LuauAddMissingFollow ? follow(*superIter) : *superIter)) + else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) { superIter.advance(); continue; @@ -887,24 +914,21 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal ice("passed non-function types to unifyFunction"); size_t numGenerics = lf->generics.size(); - if (FFlag::LuauGenericFunctions && numGenerics != rf->generics.size()) + if (numGenerics != rf->generics.size()) { numGenerics = std::min(lf->generics.size(), rf->generics.size()); errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } size_t numGenericPacks = lf->genericPacks.size(); - if (FFlag::LuauGenericFunctions && numGenericPacks != rf->genericPacks.size()) + if (numGenericPacks != rf->genericPacks.size()) { numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - if (FFlag::LuauGenericFunctions) - { - for (size_t i = 0; i < numGenerics; i++) - log.pushSeen(lf->generics[i], rf->generics[i]); - } + for (size_t i = 0; i < numGenerics; i++) + log.pushSeen(lf->generics[i], rf->generics[i]); CountMismatch::Context context = ctx; @@ -931,22 +955,19 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal tryUnify_(lf->retType, rf->retType); } - if (lf->definition && !rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !subTy->persistent)) + if (lf->definition && !rf->definition && !subTy->persistent) { rf->definition = lf->definition; } - else if (!lf->definition && rf->definition && (!FFlag::LuauDontMutatePersistentFunctions || !superTy->persistent)) + else if (!lf->definition && rf->definition && !superTy->persistent) { lf->definition = rf->definition; } ctx = context; - if (FFlag::LuauGenericFunctions) - { - for (int i = int(numGenerics) - 1; 0 <= i; i--) - log.popSeen(lf->generics[i], rf->generics[i]); - } + for (int i = int(numGenerics) - 1; 0 <= i; i--) + log.popSeen(lf->generics[i], rf->generics[i]); } namespace @@ -1032,7 +1053,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, r->second.type); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1047,7 +1073,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, rt->indexer->indexResultType); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1083,7 +1114,12 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, lt->indexer->indexResultType); - checkChildUnifierTypeMismatch(innerState.errors, left, right); + + if (FFlag::LuauExtendedTypeMismatchError) + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + else + checkChildUnifierTypeMismatch(innerState.errors, left, right); + if (innerState.errors.empty()) log.concat(std::move(innerState.log)); else @@ -1384,21 +1420,8 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio const auto& r = rt->props.find(it.first); if (r == rt->props.end()) { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } + if (isOptional(it.second.type)) + continue; missingPropertiesInSuper.push_back(it.first); @@ -1482,21 +1505,8 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio const auto& r = lt->props.find(it.first); if (r == lt->props.end()) { - if (FFlag::LuauSealedTableUnifyOptionalFix) - { - if (isOptional(it.second.type)) - continue; - } - else - { - if (get(it.second.type)) - { - const UnionTypeVar* possiblyOptional = get(it.second.type); - const std::vector& options = possiblyOptional->options; - if (options.end() != std::find_if(options.begin(), options.end(), isNil)) - continue; - } - } + if (isOptional(it.second.type)) + continue; extraPropertiesInSub.push_back(it.first); } @@ -1526,7 +1536,18 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse innerState.tryUnify_(lhs->table, rhs->table); innerState.tryUnify_(lhs->metatable, rhs->metatable); - checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + if (FFlag::LuauExtendedTypeMismatchError) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); + } + else + { + checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); + } log.concat(std::move(innerState.log)); } @@ -1613,10 +1634,34 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - tryUnify_(prop.type, singletonTypes.errorType); + + if (!FFlag::LuauExtendedClassMismatchError) + tryUnify_(prop.type, singletonTypes.errorType); } else - tryUnify_(prop.type, classProp->type); + { + if (FFlag::LuauExtendedClassMismatchError) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, classProp->type); + + checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + } + else + { + ok = false; + innerState.log.rollback(); + } + } + else + { + tryUnify_(prop.type, classProp->type); + } + } } if (table->indexer) @@ -1649,45 +1694,24 @@ static void queueTypePack_DEPRECATED( while (true) { - if (FFlag::LuauAddMissingFollow) - a = follow(a); + a = follow(a); if (seenTypePacks.count(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauAddMissingFollow) + if (get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; } - else + else if (auto tp = get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - - if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } @@ -1698,45 +1722,24 @@ static void queueTypePack(std::vector& queue, DenseHashSet& while (true) { - if (FFlag::LuauAddMissingFollow) - a = follow(a); + a = follow(a); if (seenTypePacks.find(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauAddMissingFollow) + if (get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; } - else + else if (auto tp = get(a)) { - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - - if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } @@ -1990,33 +1993,6 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N return Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); } -std::optional Unifier::findMetatableEntry(TypeId type, std::string entry) -{ - type = follow(type); - - std::optional metatable = getMetatable(type); - if (!metatable) - return std::nullopt; - - TypeId unwrapped = follow(*metatable); - - if (get(unwrapped)) - return singletonTypes.anyType; - - const TableTypeVar* mtt = getTableType(unwrapped); - if (!mtt) - { - errors.push_back(TypeError{location, GenericError{"Metatable was not a table."}}); - return std::nullopt; - } - - auto it = mtt->props.find(entry); - if (it != mtt->props.end()) - return it->second.type; - else - return std::nullopt; -} - void Unifier::occursCheck(TypeId needle, TypeId haystack) { std::unordered_set seen_DEPRECATED; @@ -2168,7 +2144,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { for (const auto& ty : a->head) { - if (auto f = get(FFlag::LuauAddMissingFollow ? follow(ty) : ty)) + if (auto f = get(follow(ty))) { occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); occursCheck(seen_DEPRECATED, seen, needle, f->retType); @@ -2207,6 +2183,17 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); } +void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) +{ + LUAU_ASSERT(FFlag::LuauExtendedTypeMismatchError || FFlag::LuauExtendedClassMismatchError); + + if (auto e = hasUnificationTooComplex(innerErrors)) + errors.push_back(*e); + else if (!innerErrors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible", prop.c_str()), innerErrors.front()}}); +} + void Unifier::ice(const std::string& message, const Location& location) { sharedState.iceHandler->ice(message, location); diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 42c64dc92..39c7d9251 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -282,7 +282,6 @@ class Parser // `<' namelist `>' std::pair, AstArray> parseGenericTypeList(); - std::pair, AstArray> parseGenericTypeListIfFFlagParseGenericFunctions(); // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 846bc0ba9..a1bad65ef 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,13 +10,13 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsParserFix, false) -LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctions, false) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) +LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau { @@ -957,7 +957,7 @@ AstStat* Parser::parseAssignment(AstExpr* initial) { nextLexeme(); - AstExpr* expr = parsePrimaryExpr(/* asStatement= */ false); + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ FFlag::LuauFixAmbiguousErrorRecoveryInAssign); if (!isExprLValue(expr)) expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field"); @@ -995,7 +995,7 @@ std::pair Parser::parseFunctionBody( { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + auto [generics, genericPacks] = parseGenericTypeList(); Lexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1343,19 +1343,18 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) { incrementRecursionCounter("type annotation"); - bool monomorphic = !(FFlag::LuauParseGenericFunctions && lexer.current().type == '<'); - - auto [generics, genericPacks] = parseGenericTypeListIfFFlagParseGenericFunctions(); + bool monomorphic = lexer.current().type != '<'; Lexeme begin = lexer.current(); - if (FFlag::LuauGenericFunctionsParserFix) - expectAndConsume('(', "function parameters"); - else - { - LUAU_ASSERT(begin.type == '('); - nextLexeme(); // ( - } + auto [generics, genericPacks] = parseGenericTypeList(); + + Lexeme parameterStart = lexer.current(); + + if (!FFlag::LuauParseGenericFunctionTypeBegin) + begin = parameterStart; + + expectAndConsume('(', "function parameters"); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; @@ -1366,7 +1365,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) if (lexer.current().type != ')') varargAnnotation = parseTypeList(params, names); - expectMatchAndConsume(')', begin, true); + expectMatchAndConsume(')', parameterStart, true); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]--; @@ -1585,7 +1584,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) { return {parseTableTypeAnnotation(), {}}; } - else if (lexer.current().type == '(' || (FFlag::LuauParseGenericFunctions && lexer.current().type == '<')) + else if (lexer.current().type == '(' || lexer.current().type == '<') { return parseFunctionTypeAnnotation(allowPack); } @@ -2315,19 +2314,6 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeListIfFFlagParseGenericFunctions() -{ - if (FFlag::LuauParseGenericFunctions) - return Parser::parseGenericTypeList(); - AstArray generics; - AstArray genericPacks; - generics.size = 0; - generics.data = nullptr; - genericPacks.size = 0; - genericPacks.data = nullptr; - return std::pair(generics, genericPacks); -} - std::pair, AstArray> Parser::parseGenericTypeList() { TempVector names{scratchName}; @@ -2342,7 +2328,7 @@ std::pair, AstArray> Parser::parseGenericTypeList() while (true) { AstName name = parseName().name; - if (FFlag::LuauParseGenericFunctions && lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3) { seenPack = true; nextLexeme(); @@ -2379,15 +2365,12 @@ AstArray Parser::parseTypeParams() Lexeme begin = lexer.current(); nextLexeme(); - bool seenPack = false; while (true) { if (FFlag::LuauParseTypePackTypeParameters) { if (shouldParseTypePackAnnotation(lexer)) { - seenPack = true; - auto typePack = parseTypePackAnnotation(); if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them @@ -2399,8 +2382,6 @@ AstArray Parser::parseTypeParams() if (typePack) { - seenPack = true; - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them parameters.push_back({{}, typePack}); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index ed0552d74..9ab10aaf0 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -34,8 +34,10 @@ static void report(ReportFormat format, const char* name, const Luau::Location& } } -static void reportError(ReportFormat format, const char* name, const Luau::TypeError& error) +static void reportError(ReportFormat format, const Luau::TypeError& error) { + const char* name = error.moduleName.c_str(); + if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); else @@ -49,7 +51,10 @@ static void reportWarning(ReportFormat format, const char* name, const Luau::Lin static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat format, bool annotate) { - Luau::CheckResult cr = frontend.check(name); + Luau::CheckResult cr; + + if (frontend.isDirty(name)) + cr = frontend.check(name); if (!frontend.getSourceModule(name)) { @@ -58,7 +63,7 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat } for (auto& error : cr.errors) - reportError(format, name, error); + reportError(format, error); Luau::LintResult lr = frontend.lint(name); @@ -115,7 +120,12 @@ struct CliFileResolver : Luau::FileResolver { if (Luau::AstExprConstantString* expr = node->as()) { - Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".lua"; + Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; + if (!moduleExists(name)) + { + // fall back to .lua if a module with .luau doesn't exist + name = std::string(expr->value.data, expr->value.size) + ".lua"; + } return {{name}}; } @@ -236,8 +246,15 @@ int main(int argc, char** argv) if (isDirectory(argv[i])) { traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + // Look for .luau first and if absent, fall back to .lua + if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) + { + failed += !analyzeFile(frontend, name.c_str(), format, annotate); + } + else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + { failed += !analyzeFile(frontend, name.c_str(), format, annotate); + } }); } else diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 4968d0800..5c904cca4 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -13,6 +13,17 @@ #include +#ifdef _WIN32 +#include +#include +#endif + +enum class CompileFormat +{ + Default, + Binary +}; + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -51,9 +62,13 @@ static int lua_require(lua_State* L) return finishrequire(L); lua_pop(L, 1); - std::optional source = readFile(name + ".lua"); + std::optional source = readFile(name + ".luau"); if (!source) - luaL_argerrorL(L, 1, ("error loading " + name).c_str()); + { + source = readFile(name + ".lua"); // try .lua if .luau doesn't exist + if (!source) + luaL_argerrorL(L, 1, ("error loading " + name).c_str()); // if neither .luau nor .lua exist, we have an error + } // module needs to run in a new thread, isolated from the rest lua_State* GL = lua_mainthread(L); @@ -183,6 +198,11 @@ static std::string runCode(lua_State* L, const std::string& source) error += "\nstack backtrace:\n"; error += lua_debugtrace(T); +#ifdef __EMSCRIPTEN__ + // nicer formatting for errors in web repl + error = "Error:" + error; +#endif + fprintf(stdout, "%s", error.c_str()); } @@ -190,6 +210,39 @@ static std::string runCode(lua_State* L, const std::string& source) return std::string(); } +#ifdef __EMSCRIPTEN__ +extern "C" +{ + const char* executeScript(const char* source) + { + // setup flags + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + // create new state + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup state + setupState(L); + + // sandbox thread + luaL_sandboxthread(L); + + // static string for caching result (prevents dangling ptr on function exit) + static std::string result; + + // run code + collect error + result = runCode(L, source); + + return result.empty() ? NULL : result.c_str(); + } +} +#endif + +// Excluded from emscripten compilation to avoid -Wunused-function errors. +#ifndef __EMSCRIPTEN__ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; @@ -366,7 +419,7 @@ static void reportError(const char* name, const Luau::CompileError& error) report(name, error.getLocation(), "CompileError", error.what()); } -static bool compileFile(const char* name) +static bool compileFile(const char* name, CompileFormat format) { std::optional source = readFile(name); if (!source) @@ -383,7 +436,15 @@ static bool compileFile(const char* name) Luau::compileOrThrow(bcb, *source); - printf("%s", bcb.dumpEverything().c_str()); + switch (format) + { + case CompileFormat::Default: + printf("%s", bcb.dumpEverything().c_str()); + break; + case CompileFormat::Binary: + fwrite(bcb.getBytecode().data(), 1, bcb.getBytecode().size(), stdout); + break; + } return true; } @@ -408,7 +469,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available modes:\n"); printf(" omitted: compile and run input files one by one\n"); - printf(" --compile: compile input files and output resulting bytecode\n"); + printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); @@ -440,8 +501,19 @@ int main(int argc, char** argv) return 0; } - if (argc >= 2 && strcmp(argv[1], "--compile") == 0) + + if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { + CompileFormat format = CompileFormat::Default; + + if (strcmp(argv[1], "--compile=binary") == 0) + format = CompileFormat::Binary; + +#ifdef _WIN32 + if (format == CompileFormat::Binary) + _setmode(_fileno(stdout), _O_BINARY); +#endif + int failed = 0; for (int i = 2; i < argc; ++i) @@ -452,13 +524,15 @@ int main(int argc, char** argv) if (isDirectory(argv[i])) { traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str()); + if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) + failed += !compileFile(name.c_str(), format); + else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) + failed += !compileFile(name.c_str(), format); }); } else { - failed += !compileFile(argv[i]); + failed += !compileFile(argv[i], format); } } @@ -511,5 +585,6 @@ int main(int argc, char** argv) return failed; } } +#endif diff --git a/CMakeLists.txt b/CMakeLists.txt index d6598f2a9..36014a983 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,17 +17,26 @@ add_library(Luau.VM STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) - add_executable(Luau.Analyze.CLI) + if(NOT EMSCRIPTEN) + add_executable(Luau.Analyze.CLI) + else() + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") + endif() # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) - set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + + if(NOT EMSCRIPTEN) + set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + endif() endif() -if(LUAU_BUILD_TESTS) +if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) endif() + include(Sources.cmake) target_compile_features(Luau.Ast PUBLIC cxx_std_17) @@ -53,10 +62,6 @@ if(MSVC) else() list(APPEND LUAU_OPTIONS -Wall) # All warnings list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors - - if(CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - list(APPEND LUAU_OPTIONS -Wno-unused) # GCC considers variables declared/checked in if() as unused - endif() endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) @@ -65,7 +70,10 @@ target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) - target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + + if(NOT EMSCRIPTEN) + target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + endif() target_include_directories(Luau.Repl.CLI PRIVATE extern) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) @@ -74,10 +82,20 @@ if(LUAU_BUILD_CLI) target_link_libraries(Luau.Repl.CLI PRIVATE pthread) endif() - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + if(NOT EMSCRIPTEN) + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + endif() + + if(EMSCRIPTEN) + # declare exported functions to emscripten + target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions) + + # custom output directory for wasm + js file + set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau) + endif() endif() -if(LUAU_BUILD_TESTS) +if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 4b03ed1c7..71631d101 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -467,6 +467,10 @@ enum LuauBuiltinFunction // vector ctor LBF_VECTOR, + + // bit32.count + LBF_BIT32_COUNTLZ, + LBF_BIT32_COUNTRZ, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index f8d671588..4f88e602e 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -36,6 +36,9 @@ struct CompileOptions // global builtin to construct vectors; disabled by default const char* vectorLib = nullptr; const char* vectorCtor = nullptr; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals = nullptr; }; class CompileError : public std::exception diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7750a1d9f..9712f02f4 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -13,7 +13,9 @@ LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) +LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) +LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) namespace Luau { @@ -22,6 +24,7 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; +// TODO: Remove with LuauGenericSpecialGlobals static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; CompileError::CompileError(const Location& location, const std::string& message) @@ -1277,7 +1280,7 @@ struct Compiler { const Global* global = globals.find(expr->name); - return options.optimizationLevel >= 1 && (!global || (!global->written && !global->special)); + return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable)); } void compileExprIndexName(AstExprIndexName* expr, uint8_t target) @@ -2465,9 +2468,10 @@ struct Compiler } else if (node->is()) { + LUAU_ASSERT(!loops.empty()); + // before exiting out of the loop, we need to close all local variables that were captured in closures since loop start // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here - LUAU_ASSERT(!loops.empty()); closeLocals(loops.back().localOffset); size_t label = bytecode.emitLabel(); @@ -2478,12 +2482,13 @@ struct Compiler } else if (AstStatContinue* stat = node->as()) { + LUAU_ASSERT(!loops.empty()); + if (loops.back().untilCondition) validateContinueUntil(stat, loops.back().untilCondition); // before continuing, we need to close all local variables that were captured in closures since loop start // normally they are closed by the enclosing blocks, including the loop block, but we're skipping that here - LUAU_ASSERT(!loops.empty()); closeLocals(loops.back().localOffset); size_t label = bytecode.emitLabel(); @@ -2900,6 +2905,11 @@ struct Compiler break; case AstExprUnary::Len: + if (arg.type == Constant::Type_String) + { + result.type = Constant::Type_Number; + result.valueNumber = double(arg.valueString.size); + } break; default: @@ -3440,7 +3450,7 @@ struct Compiler struct Global { - bool special = false; + bool writable = false; bool written = false; }; @@ -3498,7 +3508,7 @@ struct Compiler { Global* g = globals.find(object->name); - return !g || (!g->special && !g->written) ? Builtin{object->name, expr->index} : Builtin(); + return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin(); } else { @@ -3629,6 +3639,10 @@ struct Compiler return LBF_BIT32_RROTATE; if (builtin.method == "rshift") return LBF_BIT32_RSHIFT; + if (builtin.method == "countlz" && FFlag::LuauBit32CountBuiltin) + return LBF_BIT32_COUNTLZ; + if (builtin.method == "countrz" && FFlag::LuauBit32CountBuiltin) + return LBF_BIT32_COUNTRZ; } if (builtin.object == "string") @@ -3696,13 +3710,24 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); - // since access to some global objects may result in values that change over time, we block table imports - for (const char* global : kSpecialGlobals) + // since access to some global objects may result in values that change over time, we block imports from non-readonly tables + if (FFlag::LuauGenericSpecialGlobals) { - AstName name = names.get(global); + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (name.value) - compiler.globals[name].special = true; + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) + compiler.globals[name].writable = true; + } + else + { + for (const char* global : kSpecialGlobals) + { + if (AstName name = names.get(global); name.value) + compiler.globals[name].writable = true; + } } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written @@ -3717,7 +3742,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1) + if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Makefile b/Makefile index 7788251d8..5d51b3d4e 100644 --- a/Makefile +++ b/Makefile @@ -49,7 +49,10 @@ OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(T CXXFLAGS=-g -Wall -Werror LDFLAGS= -CXXFLAGS+=-Wno-unused # temporary, for older gcc versions +# temporary, for older gcc versions as they treat var in `if (type var = val)` as unused +ifeq ($(findstring g++,$(shell $(CXX) --version)),g++) + CXXFLAGS+=-Wno-unused +endif # configuration-specific flags ifeq ($(config),release) @@ -134,12 +137,11 @@ $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET): # executable targets for fuzzing fuzz-%: $(BUILD)/fuzz/%.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) + $(CXX) $^ $(LDFLAGS) -o $@ + fuzz-proto: $(BUILD)/fuzz/proto.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator fuzz-prototest: $(BUILD)/fuzz/prototest.cpp.o $(BUILD)/fuzz/protoprint.cpp.o $(BUILD)/fuzz/luau.pb.cpp.o $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) | build/libprotobuf-mutator -fuzz-%: - $(CXX) $^ $(LDFLAGS) -o $@ - # static library targets $(AST_TARGET): $(AST_OBJECTS) $(COMPILER_TARGET): $(COMPILER_OBJECTS) diff --git a/VM/include/lua.h b/VM/include/lua.h index 2f93ad903..a9d3e875a 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -213,6 +213,8 @@ LUA_API int lua_resume(lua_State* L, lua_State* from, int narg); LUA_API int lua_resumeerror(lua_State* L, lua_State* from); LUA_API int lua_status(lua_State* L); LUA_API int lua_isyieldable(lua_State* L); +LUA_API void* lua_getthreaddata(lua_State* L); +LUA_API void lua_setthreaddata(lua_State* L, void* data); /* ** garbage-collection function and options @@ -346,6 +348,8 @@ struct lua_Debug * can only be changed when the VM is not running any code */ struct lua_Callbacks { + void* userdata; /* arbitrary userdata pointer that is never overwritten by Luau */ + void (*interrupt)(lua_State* L, int gc); /* gets called at safepoints (loop back edges, call/ret, gc) if set */ void (*panic)(lua_State* L, int errcode); /* gets called when an unprotected error is raised (if longjmp is used) */ diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f2e97c669..7e742644f 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -703,6 +703,7 @@ void lua_setreadonly(lua_State* L, int objindex, bool value) const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); + api_check(L, t != hvalue(registry(L))); t->readonly = value; return; } @@ -987,6 +988,16 @@ int lua_status(lua_State* L) return L->status; } +void* lua_getthreaddata(lua_State* L) +{ + return L->userdata; +} + +void lua_setthreaddata(lua_State* L, void* data) +{ + L->userdata = data; +} + /* ** Garbage-collection function */ diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 0754a351d..c72fe6748 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -4,6 +4,8 @@ #include "lnumutils.h" +LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) + #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) @@ -177,6 +179,44 @@ static int b_replace(lua_State* L) return 1; } +static int b_countlz(lua_State* L) +{ + if (!FFlag::LuauBit32Count) + luaL_error(L, "bit32.countlz isn't enabled"); + + b_uint v = luaL_checkunsigned(L, 1); + + b_uint r = NBITS; + for (int i = 0; i < NBITS; ++i) + if (v & (1u << (NBITS - 1 - i))) + { + r = i; + break; + } + + lua_pushunsigned(L, r); + return 1; +} + +static int b_countrz(lua_State* L) +{ + if (!FFlag::LuauBit32Count) + luaL_error(L, "bit32.countrz isn't enabled"); + + b_uint v = luaL_checkunsigned(L, 1); + + b_uint r = NBITS; + for (int i = 0; i < NBITS; ++i) + if (v & (1u << i)) + { + r = i; + break; + } + + lua_pushunsigned(L, r); + return 1; +} + static const luaL_Reg bitlib[] = { {"arshift", b_arshift}, {"band", b_and}, @@ -190,6 +230,8 @@ static const luaL_Reg bitlib[] = { {"replace", b_replace}, {"rrotate", b_rrot}, {"rshift", b_rshift}, + {"countlz", b_countlz}, + {"countrz", b_countrz}, {NULL, NULL}, }; diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index e1c99b21a..9ab57ac9b 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -20,8 +20,9 @@ // If types of the arguments mismatch, luauF_* needs to return -1 and the execution will fall back to the usual call path // If luauF_* succeeds, it needs to return *all* requested arguments, filling results with nil as appropriate. // On input, nparams refers to the actual number of arguments (0+), whereas nresults contains LUA_MULTRET for arbitrary returns or 0+ for a -// fixed-length return Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults -// is <= expected number, which covers the LUA_MULTRET case. +// fixed-length return +// Because of this, and the fact that "extra" returned values will be ignored, implementations below typically check that nresults is <= expected +// number, which covers the LUA_MULTRET case. static int luauF_assert(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { @@ -1030,6 +1031,52 @@ static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, Stk return -1; } +static int luauF_countlz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + + unsigned n; + luai_num2unsigned(n, a1); + +#ifdef _MSC_VER + unsigned long rl; + int r = _BitScanReverse(&rl, n) ? 31 - int(rl) : 32; +#else + int r = n == 0 ? 32 : __builtin_clz(n); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + +static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams >= 1 && nresults <= 1 && ttisnumber(arg0)) + { + double a1 = nvalue(arg0); + + unsigned n; + luai_num2unsigned(n, a1); + +#ifdef _MSC_VER + unsigned long rl; + int r = _BitScanForward(&rl, n) ? int(rl) : 32; +#else + int r = n == 0 ? 32 : __builtin_ctz(n); +#endif + + setnvalue(res, double(r)); + return 1; + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1096,4 +1143,7 @@ luau_FastFunction luauF_table[256] = { luauF_tunpack, luauF_vector, + + luauF_countlz, + luauF_countrz, }; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 6553009ff..11f79d1a3 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,7 +13,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) -LUAU_FASTFLAGVARIABLE(LuauConsolidatedStep, false) LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -677,117 +676,6 @@ static size_t atomic(lua_State* L) return work; } -static size_t singlestep(lua_State* L) -{ - size_t cost = 0; - global_State* g = L->global; - switch (g->gcstate) - { - case GCSpause: - { - markroot(L); /* start a new collection */ - LUAU_ASSERT(g->gcstate == GCSpropagate); - break; - } - case GCSpropagate: - { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else - { - // perform one iteration over 'gray again' list - g->gray = g->grayagain; - g->grayagain = NULL; - - g->gcstate = GCSpropagateagain; - } - break; - } - case GCSpropagateagain: - { - if (g->gray) - { - g->gcstats.currcycle.markitems++; - - cost = propagatemark(g); - } - else /* no more `gray' objects */ - { - if (FFlag::LuauSeparateAtomic) - { - g->gcstate = GCSatomic; - } - else - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } - } - break; - } - case GCSatomic: - { - g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - cost = atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - break; - } - case GCSsweepstring: - { - size_t traversedcount = 0; - sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); - - // nothing more to sweep? - if (g->sweepstrgc >= g->strt.size) - { - // sweep string buffer list and preserve used string count - uint32_t nuse = L->global->strt.nuse; - sweepwholelist(L, &g->strbufgc, &traversedcount); - L->global->strt.nuse = nuse; - - g->gcstate = GCSsweep; // end sweep-string phase - } - - g->gcstats.currcycle.sweepitems += traversedcount; - - cost = GC_SWEEPCOST; - break; - } - case GCSsweep: - { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); - - g->gcstats.currcycle.sweepitems += traversedcount; - - if (*g->sweepgc == NULL) - { /* nothing more to sweep? */ - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ - } - cost = GC_SWEEPMAX * GC_SWEEPCOST; - break; - } - default: - LUAU_ASSERT(!"Unexpected GC state"); - } - - return cost; -} - static size_t gcstep(lua_State* L, size_t limit) { size_t cost = 0; @@ -980,37 +868,12 @@ void luaC_step(lua_State* L, bool assist) int lastgcstate = g->gcstate; double lasttimestamp = lua_clock(); - if (FFlag::LuauConsolidatedStep) - { - size_t work = gcstep(L, lim); + size_t work = gcstep(L, lim); - if (assist) - g->gcstats.currcycle.assistwork += work; - else - g->gcstats.currcycle.explicitwork += work; - } + if (assist) + g->gcstats.currcycle.assistwork += work; else - { - // always perform at least one single step - do - { - lim -= singlestep(L); - - // if we have switched to a different state, capture the duration of last stage - // this way we reduce the number of timer calls we make - if (lastgcstate != g->gcstate) - { - GC_INTERRUPT(lastgcstate); - - double now = lua_clock(); - - recordGcStateTime(g, lastgcstate, now - lasttimestamp, assist); - - lasttimestamp = now; - lastgcstate = g->gcstate; - } - } while (lim > 0 && g->gcstate != GCSpause); - } + g->gcstats.currcycle.explicitwork += work; recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); @@ -1037,14 +900,7 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold -= debt; } - if (FFlag::LuauConsolidatedStep) - { - GC_INTERRUPT(lastgcstate); - } - else - { - GC_INTERRUPT(g->gcstate); - } + GC_INTERRUPT(lastgcstate); } void luaC_fullgc(lua_State* L) @@ -1070,10 +926,7 @@ void luaC_fullgc(lua_State* L) while (g->gcstate != GCSpause) { LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - if (FFlag::LuauConsolidatedStep) - gcstep(L, SIZE_MAX); - else - singlestep(L); + gcstep(L, SIZE_MAX); } finishGcCycleStats(g); @@ -1084,10 +937,7 @@ void luaC_fullgc(lua_State* L) markroot(L); while (g->gcstate != GCSpause) { - if (FFlag::LuauConsolidatedStep) - gcstep(L, SIZE_MAX); - else - singlestep(L); + gcstep(L, SIZE_MAX); } /* reclaim as much buffer memory as possible (shrinkbuffers() called during sweep is incremental) */ shrinkbuffersfull(L); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index a9db37272..80a34483a 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,6 +8,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false) + /* macro to `unsign' a character */ #define uchar(c) ((unsigned char)(c)) @@ -1404,10 +1406,20 @@ static int str_pack(lua_State* L) } case Kuint: { /* unsigned integers */ - unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, n, h.islittle, size, 0); + if (FFlag::LuauStrPackUBCastFix) + { + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, (unsigned long long)n, h.islittle, size, 0); + } + else + { + unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, n, h.islittle, size, 0); + } break; } case Kfloat: diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 883442ae0..07d22d596 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -30,6 +30,7 @@ LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) +static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast in gval2slot is incorrect"); // TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); diff --git a/VM/src/ltable.h b/VM/src/ltable.h index f98d87b1d..45061443e 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -9,7 +9,6 @@ #define gval(n) (&(n)->val) #define gnext(n) ((n)->key.next) -static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast below is incorrect"); #define gval2slot(t, v) int(cast_to(LuaNode*, static_cast(v)) - t->node) LUAI_FUNC const TValue* luaH_getnum(Table* t, int key); diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index de5788eb6..370258189 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -9,8 +9,6 @@ #include "ldebug.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauTableFreeze, false) - static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -491,9 +489,6 @@ static int tclear(lua_State* L) static int tfreeze(lua_State* L) { - if (!FFlag::LuauTableFreeze) - luaG_runerror(L, "table.freeze is disabled"); - luaL_checktype(L, 1, LUA_TTABLE); luaL_argcheck(L, !lua_getreadonly(L, 1), 1, "table is already frozen"); luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); @@ -506,9 +501,6 @@ static int tfreeze(lua_State* L) static int tisfrozen(lua_State* L) { - if (!FFlag::LuauTableFreeze) - luaG_runerror(L, "table.isfrozen is disabled"); - luaL_checktype(L, 1, LUA_TTABLE); lua_pushboolean(L, lua_getreadonly(L, 1)); diff --git a/bench/tests/chess.lua b/bench/tests/chess.lua index 87b9abfd4..f6ae2cc6b 100644 --- a/bench/tests/chess.lua +++ b/bench/tests/chess.lua @@ -205,38 +205,48 @@ function Bitboard:empty() return self.h == 0 and self.l == 0 end -function Bitboard:ctz() - local target = self.l - local offset = 0 - local result = 0 - - if target == 0 then - target = self.h - result = 32 +if not bit32.countrz then + local function ctz(v) + if v == 0 then return 32 end + local offset = 0 + while bit32.extract(v, offset) == 0 do + offset = offset + 1 + end + return offset end - - if target == 0 then - return 64 + function Bitboard:ctz() + local result = ctz(self.l) + if result == 32 then + return ctz(self.h) + 32 + else + return result + end end - - while bit32.extract(target, offset) == 0 do - offset = offset + 1 + function Bitboard:ctzafter(start) + start = start + 1 + if start < 32 then + for i=start,31 do + if bit32.extract(self.l, i) == 1 then return i end + end + end + for i=math.max(32,start),63 do + if bit32.extract(self.h, i-32) == 1 then return i end + end + return 64 end - - return result + offset -end - -function Bitboard:ctzafter(start) - start = start + 1 - if start < 32 then - for i=start,31 do - if bit32.extract(self.l, i) == 1 then return i end +else + function Bitboard:ctz() + local result = bit32.countrz(self.l) + if result == 32 then + return bit32.countrz(self.h) + 32 + else + return result end end - for i=math.max(32,start),63 do - if bit32.extract(self.h, i-32) == 1 then return i end + function Bitboard:ctzafter(start) + local masked = self:band(Bitboard.full:lshift(start+1)) + return masked:ctz() end - return 64 end @@ -245,7 +255,7 @@ function Bitboard:lshift(amt) if amt == 0 then return self end if amt > 31 then - return Bitboard.from(0, bit32.lshift(self.l, amt-31)) + return Bitboard.from(0, bit32.lshift(self.l, amt-32)) end local l = bit32.lshift(self.l, amt) @@ -832,12 +842,12 @@ end local testCases = {} local function addTest(...) table.insert(testCases, {...}) end -addTest(StartingFen, 3, 8902) -addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 2, 2039) -addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 3, 2812) -addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 3, 9467) -addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 2, 1486) -addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 2, 2079) +addTest(StartingFen, 2, 400) +addTest("r3k2r/p1ppqpb1/bn2pnp1/3PN3/1p2P3/2N2Q1p/PPPBBPPP/R3K2R w KQkq - 0 0", 1, 48) +addTest("8/2p5/3p4/KP5r/1R3p1k/8/4P1P1/8 w - - 0 0", 2, 191) +addTest("r3k2r/Pppp1ppp/1b3nbN/nP6/BBP1P3/q4N2/Pp1P2PP/R2Q1RK1 w kq - 0 1", 2, 264) +addTest("rnbq1k1r/pp1Pbppp/2p5/8/2B5/8/PPP1NnPP/RNBQK2R w KQ - 1 8", 1, 44) +addTest("r4rk1/1pp1qppp/p1np1n2/2b1p1B1/2B1P1b1/P1NP1N2/1PP1QPPP/R4RK1 w - - 0 10", 1, 46) local function chess() diff --git a/fuzz/luau.proto b/fuzz/luau.proto index 41a1d077f..c78fcf31c 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -19,6 +19,7 @@ message Expr { ExprTable table = 13; ExprUnary unary = 14; ExprBinary binary = 15; + ExprIfElse ifelse = 16; } } @@ -149,6 +150,12 @@ message ExprBinary { required Expr right = 3; } +message ExprIfElse { + required Expr cond = 1; + required Expr then = 2; + required Expr else = 3; +} + message LValue { oneof lvalue_oneof { ExprLocal local = 1; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 6c230b67f..c85fac7d3 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -11,6 +11,7 @@ #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" #include "Luau/ToString.h" +#include "Luau/Transpiler.h" #include "lua.h" #include "lualib.h" @@ -23,6 +24,7 @@ const bool kFuzzLinter = true; const bool kFuzzTypeck = true; const bool kFuzzVM = true; const bool kFuzzTypes = true; +const bool kFuzzTranspile = true; static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); @@ -242,6 +244,11 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) } } + if (kFuzzTranspile && parseResult.root) + { + transpileWithTypes(*parseResult.root); + } + // run resulting bytecode if (kFuzzVM && bytecode.size()) { diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index 2c861a553..e61b69365 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -476,6 +476,16 @@ struct ProtoToLuau print(expr.right()); } + void print(const luau::ExprIfElse& expr) + { + source += " if "; + print(expr.cond()); + source += " then "; + print(expr.then()); + source += " else "; + print(expr.else_()); + } + void print(const luau::LValue& expr) { if (expr.has_local()) diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index dd49e675e..aa53a92b2 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -45,7 +45,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") { ScopedFastFlag sffs[] = { - {"LuauDontMutatePersistentFunctions", true}, {"LuauPersistDefinitionFileTypes", true}, }; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 44b8362df..8a7798f3a 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1287,9 +1287,6 @@ local e: (n: n@5 TEST_CASE_FIXTURE(ACFixture, "generic_types") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - check(R"( function f(a: T@1 local b: string = "don't trip" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index bbac3302c..7f03019c3 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -13,6 +13,7 @@ LUAU_FASTFLAG(LuauPreloadClosures) LUAU_FASTFLAG(LuauPreloadClosuresFenv) LUAU_FASTFLAG(LuauPreloadClosuresUpval) +LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -1168,6 +1169,17 @@ RETURN R0 1 )"); } +TEST_CASE("ConstantFoldStringLen") +{ + CHECK_EQ("\n" + compileFunction0("return #'string', #'', #'a', #('b')"), R"( +LOADN R0 6 +LOADN R1 0 +LOADN R2 1 +LOADN R3 1 +RETURN R0 4 +)"); +} + TEST_CASE("ConstantFoldCompare") { // ordered comparisons @@ -3659,4 +3671,118 @@ RETURN R0 0 )"); } +TEST_CASE("LuauGenericSpecialGlobals") +{ + const char* source = R"( +print() +Game.print() +Workspace.print() +_G.print() +game.print() +plugin.print() +script.print() +shared.print() +workspace.print() +)"; + + { + ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; + + // Check Roblox globals are here + CHECK_EQ("\n" + compileFunction0(source), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R1 3 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 5 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 9 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 11 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 13 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 15 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 17 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +RETURN R0 0 +)"); + } + + ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; + + // Check Roblox globals are no longer here + CHECK_EQ("\n" + compileFunction0(source), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R0 3 +CALL R0 0 0 +GETIMPORT R0 5 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R0 9 +CALL R0 0 0 +GETIMPORT R0 11 +CALL R0 0 0 +GETIMPORT R0 13 +CALL R0 0 0 +GETIMPORT R0 15 +CALL R0 0 0 +GETIMPORT R0 17 +CALL R0 0 0 +RETURN R0 0 +)"); + + // Check we can add them back + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + const char* mutableGlobals[] = {"Game", "Workspace", "game", "plugin", "script", "shared", "workspace", NULL}; + options.mutableGlobals = &mutableGlobals[0]; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +GETIMPORT R0 1 +CALL R0 0 0 +GETIMPORT R1 3 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 5 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 7 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 9 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 11 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 13 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 15 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +GETIMPORT R1 17 +GETTABLEKS R0 R1 K0 +CALL R0 0 0 +RETURN R0 0 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 06b3c5237..c1b790b9c 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -240,8 +240,6 @@ TEST_CASE("Math") TEST_CASE("Table") { - ScopedFastFlag sff("LuauTableFreeze", true); - runConformance("nextvar.lua"); } @@ -322,6 +320,8 @@ TEST_CASE("GC") TEST_CASE("Bitwise") { + ScopedFastFlag sff("LuauBit32Count", true); + runConformance("bitwise.lua"); } @@ -359,6 +359,8 @@ TEST_CASE("PCall") TEST_CASE("Pack") { + ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; + runConformance("tpack.lua"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 37f1b60b8..7ba40c503 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -479,10 +479,6 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { - ScopedFastFlag sff{"LuauLinterUnknownTypeVectorAware", true}; - - SourceModule sm; - unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -1400,8 +1396,6 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { - ScopedFastFlag sff("LuauLinterTableMoveZero", true); - LintResult result = lintTyped(R"( local t = {} local tt = {} diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index a80718e47..e3e6ce6d8 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauFixAmbiguousErrorRecoveryInAssign) + using namespace Luau; namespace @@ -625,10 +627,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_messages") )"), "Cannot have more than one table indexer"); - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauGenericFunctionsParserFix", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CHECK_EQ(getParseError(R"( type T = foo )"), @@ -1624,6 +1622,20 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") "statements"); CHECK(result3.errors.size() == 1); + + auto result4 = matchParseError(R"( + local t = {} + function f() return t end + t.x, (f) + ().y = 5, 6 + )", + "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " + "statements"); + + if (FFlag::LuauFixAmbiguousErrorRecoveryInAssign) + CHECK(result4.errors.size() == 1); + else + CHECK(result4.errors.size() == 5); } TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") @@ -1824,9 +1836,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_definition_parsing") TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( function f(...: a...) end @@ -1861,9 +1870,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_pack_parsing") TEST_CASE_FIXTURE(Fixture, "generic_function_declaration_parsing") { - // Doesn't need LuauGenericFunctions - ScopedFastFlag sffs{"LuauParseGenericFunctions", true}; - ParseResult result = parseEx(R"( declare function f() )"); @@ -1953,12 +1959,7 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (a: number, b: string, c: number) -> (d: number, e: string, f: number)", "Expected '->' when parsing function type, got "); - { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - ScopedFastFlag luauGenericFunctionsParserFix{"LuauGenericFunctionsParserFix", true}; - - matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); - } + matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } TEST_SUITE_END(); @@ -2362,8 +2363,6 @@ type Fn = ( CHECK_EQ("Expected '->' when parsing function type, got ')'", e.getErrors().front().getMessage()); } - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - try { parse(R"(type Fn = (any, string | number | ()) -> any)"); @@ -2397,8 +2396,6 @@ TEST_CASE_FIXTURE(Fixture, "AstName_comparison") TEST_CASE_FIXTURE(Fixture, "generic_type_list_recovery") { - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - try { parse(R"( @@ -2521,7 +2518,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); AstStat* stat = parse(R"( @@ -2534,4 +2530,9 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp index bb5a93c54..7081693e2 100644 --- a/tests/Predicate.test.cpp +++ b/tests/Predicate.test.cpp @@ -33,8 +33,6 @@ TEST_SUITE_BEGIN("Predicate"); TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"b", typeChecker.stringType}, {"c", typeChecker.numberType}, @@ -61,8 +59,6 @@ TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"a", typeChecker.stringType}, {"b", typeChecker.stringType}, @@ -89,8 +85,6 @@ TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - RefinementMap m{ {"a", typeChecker.stringType}, {"b", typeChecker.numberType}, diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index e18bf7cdd..b076e9ad7 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -259,9 +259,6 @@ TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names") TEST_CASE_FIXTURE(Fixture, "function_type_with_argument_names_generic") { - ScopedFastFlag luauGenericFunctions{"LuauGenericFunctions", true}; - ScopedFastFlag luauParseGenericFunctions{"LuauParseGenericFunctions", true}; - CheckResult result = check("local function f(n: number, ...: a...): (a...) return ... end"); LUAU_REQUIRE_NO_ERRORS(result); @@ -340,10 +337,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") { - ScopedFastFlag sff[] = { - {"LuauGenericFunctions", true}, - }; - CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -468,8 +461,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_inters TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; - TypeVar tableTy{TableTypeVar{}}; TableTypeVar* ttv = getMutable(&tableTy); ttv->name = "Table"; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index bfff60f9d..928c03a31 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -21,7 +21,7 @@ local function isPortal(element) return false end - return element.component==Core.Portal + return element.component == Core.Portal end )"; @@ -223,12 +223,24 @@ TEST_CASE("escaped_strings") CHECK_EQ(code, transpile(code).code); } +TEST_CASE("escaped_strings_2") +{ + const std::string code = R"( local s="\a\b\f\n\r\t\v\'\"\\" )"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("need_a_space_between_number_literals_and_dots") { const std::string code = R"( return point and math.ceil(point* 100000* 100)/ 100000 .. '%'or '' )"; CHECK_EQ(code, transpile(code).code); } +TEST_CASE("binary_keywords") +{ + const std::string code = "local c = a0 ._ or b0 ._"; + CHECK_EQ(code, transpile(code).code); +} + TEST_CASE("do_blocks") { const std::string code = R"( @@ -364,10 +376,10 @@ TEST_CASE_FIXTURE(Fixture, "type_lists_should_be_emitted_correctly") )"; std::string expected = R"( - local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:...string): (string,...number) + local a:(string,number,...string)->(string,...number)=function(a:string,b:number,...:string): (string,...number) end - local b:(...string)->(...number)=function(...:...string): ...number + local b:(...string)->(...number)=function(...:string): ...number end local c:()->()=function(): () @@ -400,4 +412,238 @@ TEST_CASE_FIXTURE(Fixture, "function_type_location") CHECK_EQ(expected, actual); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") +{ + std::string code = "local a = 5 :: number"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") +{ + ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true); + + std::string code = "local a = if 1 then 2 else 3"; + + CHECK_EQ(code, transpile(code).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_reference_import") +{ + fileResolver.source["game/A"] = R"( +export type Type = { a: number } +return {} + )"; + + std::string code = R"( +local Import = require(game.A) +local a: Import.Type + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") +{ + ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); + ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); + + std::string code = R"( +type Packed = (T...)->(T...) +local a: Packed<> +local b: Packed<(number, string)> + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested") +{ + std::string code = "local a: ((number)->(string))|((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_2") +{ + std::string code = "local a: (number&string)|(string&boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_type_nested_3") +{ + std::string code = "local a: nil | (string & number)"; + + CHECK_EQ("local a: ( string & number)?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested") +{ + std::string code = "local a: ((number)->(string))&((string)->(string))"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_intersection_type_nested_2") +{ + std::string code = "local a: (number|string)&(string|boolean)"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_varargs") +{ + std::string code = "local function f(...) return ... end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_index_expr") +{ + std::string code = "local a = {1, 2, 3} local b = a[2]"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_unary") +{ + std::string code = R"( +local a = 1 +local b = -1 +local c = true +local d = not c +local e = 'hello' +local d = #e + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_break_continue") +{ + std::string code = R"( +local a, b, c +repeat + if a then break end + if b then continue end +until c + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_compound_assignmenr") +{ + std::string code = R"( +local a = 1 +a += 2 +a -= 3 +a *= 4 +a /= 5 +a %= 6 +a ^= 7 +a ..= ' - result' + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") +{ + std::string code = "a, b, c = 1, 2, 3"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") +{ + ScopedFastFlag luauParseGenericFunctionTypeBegin("LuauParseGenericFunctionTypeBegin", true); + + std::string code = R"( +local function foo(a: T, ...: S...) return 1 end +local f: (T, S...)->(number) = foo + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_union_reverse") +{ + std::string code = "local a: nil | number"; + + CHECK_EQ("local a: number?", transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_for_in_multiple") +{ + std::string code = "for k,v in next,{}do print(k,v) end"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_expr") +{ + std::string code = "local a = f:-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a = (error-expr: f.%error-id%)-(error-expr)", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_stat") +{ + std::string code = "-"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("(error-stat: (error-expr))", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_error_type") +{ + std::string code = "local a: "; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + CHECK_EQ("local a:%error-type%", transpileWithTypes(*parseResult.root)); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_parse_error") +{ + std::string code = "local a = -"; + + auto result = transpile(code); + CHECK_EQ("", result.code); + CHECK_EQ("Expected identifier when parsing expression, got ", result.parseError); +} + +TEST_CASE_FIXTURE(Fixture, "transpile_to_string") +{ + std::string code = "local a: string = 'hello'"; + + auto allocator = Allocator{}; + auto names = AstNameTable{allocator}; + ParseResult parseResult = Parser::parse(code.data(), code.size(), names, allocator, {}); + + REQUIRE(parseResult.root); + REQUIRE(parseResult.root->body.size == 1); + AstStatLocal* statLocal = parseResult.root->body.data[0]->as(); + REQUIRE(statLocal); + CHECK_EQ("local a: string = 'hello'", toString(statLocal)); + REQUIRE(statLocal->vars.size == 1); + AstLocal* local = statLocal->vars.data[0]; + REQUIRE(local->annotation); + CHECK_EQ("string", toString(local->annotation)); + REQUIRE(statLocal->values.size == 1); + AstExpr* expr = statLocal->values.data[0]; + CHECK_EQ("'hello'", toString(expr)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index f580604ca..c27f8083b 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -247,9 +247,6 @@ TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") { - ScopedFastFlag sffs3{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( type Node = { value: T, child: Node? } diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 17e32e9f9..1e2eae147 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -444,8 +444,6 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local co = coroutine.create(function() end) )"); @@ -456,8 +454,6 @@ TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local function nifty(x, y) print(x, y) @@ -476,8 +472,6 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( --!nonstrict local function nifty(x, y) @@ -822,8 +816,6 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") { - ScopedFastFlag sff{"LuauDontMutatePersistentFunctions", true}; - CheckResult result = check(R"( local f = math.sin local function g(x) return math.sin(x) end diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index eabf7e65d..1ff23fe69 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -232,8 +232,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class") TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() local x = 1 + c["BaseField"] @@ -244,8 +242,6 @@ TEST_CASE_FIXTURE(ClassFixture, "can_read_prop_of_base_class_using_string") TEST_CASE_FIXTURE(ClassFixture, "can_assign_to_prop_of_base_class_using_string") { - ScopedFastFlag luauClassPropertyAccessAsString("LuauClassPropertyAccessAsString", true); - CheckResult result = check(R"( local c = ChildClass.New() c["BaseField"] = 444 @@ -451,4 +447,25 @@ b.X = 2 -- real Vector2.X is also read-only CHECK_EQ("Value of type 'Vector2?' could be nil", toString(result.errors[3])); } +TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") +{ + ScopedFastFlag luauExtendedClassMismatchError{"LuauExtendedClassMismatchError", true}; + + CheckResult result = check(R"( +local function foo(v) + return v.X :: number + string.len(v.Y) +end + +local a: Vector2 +local b = foo +b(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'Vector2' could not be converted into '{- X: a, Y: string -}' +caused by: + Property 'Y' is not compatible. Type 'number' could not be converted into 'string')", + toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 41e3e45ad..2652486be 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -171,9 +171,6 @@ TEST_CASE_FIXTURE(Fixture, "no_cyclic_defined_classes") TEST_CASE_FIXTURE(Fixture, "declaring_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - loadDefinition(R"( declare function f(a: a, b: b): string declare function g(...: a...): b... diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 581375a12..de2f01544 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -13,8 +13,6 @@ TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function id(x:a): a return x @@ -27,8 +25,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_function") TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a): a return x @@ -41,10 +37,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_local_function") TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function id(...: a...): (a...) return ... end local x: string, y: boolean = id("hi", true) @@ -56,8 +48,6 @@ TEST_CASE_FIXTURE(Fixture, "check_generic_typepack_function") TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -66,8 +56,6 @@ TEST_CASE_FIXTURE(Fixture, "types_before_typepacks") TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id @@ -79,7 +67,6 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -92,7 +79,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_local_vars_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x end print("This is bogus") -- TODO: CLI-39916 @@ -104,8 +90,6 @@ TEST_CASE_FIXTURE(Fixture, "local_vars_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t = {} t.m = function(x: a):a return x end @@ -117,8 +101,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local t: { m: (number)->number } = { m = function(x:number) return x+1 end } local function id(x:a):a return x end @@ -129,8 +111,6 @@ TEST_CASE_FIXTURE(Fixture, "properties_can_be_instantiated_polytypes") TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x:a): a @@ -145,8 +125,6 @@ TEST_CASE_FIXTURE(Fixture, "check_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local function id(x:a):a local y: string = id("hi") @@ -159,8 +137,6 @@ TEST_CASE_FIXTURE(Fixture, "check_recursive_generic_function") TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( local id2 local function id1(x:a):a @@ -179,8 +155,6 @@ TEST_CASE_FIXTURE(Fixture, "check_mutual_generic_functions") TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } local x: T = { id = function(x:a):a return x end } @@ -192,8 +166,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_functions_in_types") TEST_CASE_FIXTURE(Fixture, "generic_factories") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -215,10 +187,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_factories") TEST_CASE_FIXTURE(Fixture, "factories_of_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type T = { id: (a) -> a } type Factory = { build: () -> T } @@ -241,7 +209,6 @@ TEST_CASE_FIXTURE(Fixture, "factories_of_generics") TEST_CASE_FIXTURE(Fixture, "infer_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function id(x) return x @@ -265,7 +232,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function id(x) return x @@ -289,7 +255,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_local_function") TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f() local function id(x) @@ -304,7 +269,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_nested_generic_function") TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local x = {} function x:id(x) return x end @@ -316,7 +280,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_methods") TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local x = {} function x:id(x) return x end @@ -331,8 +294,6 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") TEST_CASE_FIXTURE(Fixture, "infer_generic_property") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauRankNTypes", true}; CheckResult result = check(R"( local t = {} t.m = function(x) return x end @@ -344,9 +305,6 @@ TEST_CASE_FIXTURE(Fixture, "infer_generic_property") TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f(g: (a)->a) local x: number = g(37) @@ -358,9 +316,6 @@ TEST_CASE_FIXTURE(Fixture, "function_arguments_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function f() : (a)->a local function id(x:a):a return x end @@ -372,9 +327,6 @@ TEST_CASE_FIXTURE(Fixture, "function_results_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; CheckResult result = check(R"( local function id(x:a):a return x end local f: (a)->a = id(id) @@ -384,7 +336,6 @@ TEST_CASE_FIXTURE(Fixture, "type_parameters_can_be_polytypes") TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) -- this will only typecheck if we infer z: any @@ -406,7 +357,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( local function f(y) local z = y @@ -423,12 +373,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") { - ScopedFastFlag sffs[] = { - {"LuauGenericFunctions", true}, - {"LuauParseGenericFunctions", true}, - {"LuauRankNTypes", true}, - }; - CheckResult result = check(R"( type T = { m: (a) -> T } function f(t : T) @@ -440,10 +384,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") TEST_CASE_FIXTURE(Fixture, "dont_unify_bound_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauRankNTypes", true}; - CheckResult result = check(R"( type F = () -> (a, b) -> a type G = (b, b) -> b @@ -470,7 +410,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") // Replaying the classic problem with polymorphism and mutable state in Luau // See, e.g. Tofte (1990) // https://www.sciencedirect.com/science/article/pii/089054019090018D. - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( --!strict -- Our old friend the polymorphic identity function @@ -508,7 +447,6 @@ TEST_CASE_FIXTURE(Fixture, "mutable_state_polymorphism") TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") { - ScopedFastFlag sffs{"LuauGenericFunctions", false}; CheckResult result = check(R"( --!strict local function id(x) return x end @@ -531,8 +469,6 @@ TEST_CASE_FIXTURE(Fixture, "rank_N_types_via_typeof") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs2{"LuauParseGenericFunctions", true}; CheckResult result = check(R"( function f(x:a):a return x end )"); @@ -541,7 +477,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_types") TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -550,7 +485,6 @@ TEST_CASE_FIXTURE(Fixture, "duplicate_generic_type_packs") TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; CheckResult result = check(R"( function f() end )"); @@ -559,9 +493,6 @@ TEST_CASE_FIXTURE(Fixture, "typepacks_before_types") TEST_CASE_FIXTURE(Fixture, "variadic_generics") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a) end @@ -573,9 +504,6 @@ TEST_CASE_FIXTURE(Fixture, "variadic_generics") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): (a...) return ... end )"); @@ -586,10 +514,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_syntax") TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauGenericVariadicsUnification", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: a...): any return (...) end )"); @@ -599,9 +523,6 @@ TEST_CASE_FIXTURE(Fixture, "generic_type_pack_parentheses") TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f(...: T...) return ... @@ -626,9 +547,6 @@ TEST_CASE_FIXTURE(Fixture, "better_mismatch_error_messages") TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") { - ScopedFastFlag sffs{"LuauGenericFunctions", true}; - ScopedFastFlag sffs3{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( function f() end )"); @@ -641,8 +559,6 @@ TEST_CASE_FIXTURE(Fixture, "reject_clashing_generic_and_pack_names") TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(z) local o = {} @@ -665,8 +581,6 @@ TEST_CASE_FIXTURE(Fixture, "instantiation_sharing_types") TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( function f(x) return {5} end function g(x, y) return f(x) end @@ -680,8 +594,6 @@ TEST_CASE_FIXTURE(Fixture, "quantification_sharing_types") TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") { - ScopedFastFlag sffs1{"LuauGenericFunctions", true}; - CheckResult result = check(R"( type T = { x: {a}, y: {number} } local o1: T = { x = {true}, y = {5} } @@ -697,7 +609,6 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") { - ScopedFastFlag luauRankNTypes{"LuauRankNTypes", true}; ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; CheckResult result = check(R"( diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 9685f4f35..893bc2b30 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -341,4 +341,43 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X & Y & Z + +local a: XYZ = 3 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'number' could not be converted into 'X & Y & Z' +caused by: + Not all intersection parts are compatible. Type 'number' could not be converted into 'X')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X & Y & Z + +local a: XYZ +local b: number = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 419da8ad1..e5c14dde8 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -194,9 +194,6 @@ TEST_CASE_FIXTURE(Fixture, "normal_conditional_expression_has_refinements") // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs5{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( type Node = { value: T, child: Node? } @@ -596,11 +593,9 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - ScopedFastFlag luauFollowInTypeFunApply{"LuauFollowInTypeFunApply", true}; - ScopedFastFlag luauInstantiatedTypeParamRecursion{"LuauInstantiatedTypeParamRecursion", true}; // Mutability in type function application right now can create strange recursive types - // TODO: instantiation right now is problematic, it this example should either leave the Table type alone + // TODO: instantiation right now is problematic, in this example should either leave the Table type alone // or it should rename the type to 'Self' so that the result will be 'Self' CheckResult result = check(R"( type Table = { a: number } diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 36dcaa959..733fc39b3 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauOrPredicate) LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -133,11 +132,8 @@ TEST_CASE_FIXTURE(Fixture, "or_predicate_with_truthy_predicates") CHECK_EQ("string?", toString(requireTypeAtPosition({3, 26}))); CHECK_EQ("number?", toString(requireTypeAtPosition({4, 26}))); - if (FFlag::LuauOrPredicate) - { - CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); - CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); - } + CHECK_EQ("nil", toString(requireTypeAtPosition({6, 26}))); + CHECK_EQ("nil", toString(requireTypeAtPosition({7, 26}))); } TEST_CASE_FIXTURE(Fixture, "type_assertion_expr_carry_its_constraints") @@ -283,6 +279,8 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local t: {x: number?} = {x = nil} @@ -293,7 +291,10 @@ TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_ty )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| x: number? |}' could not be converted into '{| x: number |}'", toString(result.errors[0])); + CHECK_EQ(R"(Type '{| x: number? |}' could not be converted into '{| x: number |}' +caused by: + Property 'x' is not compatible. Type 'number?' could not be converted into 'number')", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_another_lvalue") @@ -749,8 +750,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") { - ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) == "Instance" then @@ -769,8 +768,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") { - ScopedFastFlag sff{"LuauTypeGuardPeelsAwaySubclasses", true}; - CheckResult result = check(R"( local function f(x: Part | Folder | Instance | string | Vector3 | any) if typeof(x) == "Instance" then @@ -789,8 +786,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( --!nonstrict @@ -811,11 +806,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") { - ScopedFastFlag sffs[] = { - {"LuauOrPredicate", true}, - {"LuauTypeGuardPeelsAwaySubclasses", true}, - }; - CheckResult result = check(R"( local function f(x: Part | Folder | string) if typeof(x) ~= "Instance" or not x:IsA("Part") then @@ -890,8 +880,6 @@ TEST_CASE_FIXTURE(Fixture, "type_guard_warns_on_no_overlapping_types_only_when_s TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if (not a) or (not b) then @@ -909,8 +897,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b") TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if not (a and b) then @@ -928,8 +914,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_or_not_b2") TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if (not a) and (not b) then @@ -947,8 +931,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b") TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: number?, b: number?) if not (a or b) then @@ -966,8 +948,6 @@ TEST_CASE_FIXTURE(Fixture, "not_a_and_not_b2") TEST_CASE_FIXTURE(Fixture, "either_number_or_string") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(x: any) if type(x) == "number" or type(x) == "string" then @@ -983,8 +963,6 @@ TEST_CASE_FIXTURE(Fixture, "either_number_or_string") TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(t: {x: boolean}?) if not t or t.x then @@ -1000,8 +978,6 @@ TEST_CASE_FIXTURE(Fixture, "not_t_or_some_prop_of_t") TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local a: (number | string)? assert(a) @@ -1018,8 +994,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_a_to_be_truthy_then_assert_a_to_be_number") TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - // This bug came up because there was a mistake in Luau::merge where zipping on two maps would produce the wrong merged result. CheckResult result = check(R"( local function f(b: string | { x: string }, a) @@ -1039,8 +1013,6 @@ TEST_CASE_FIXTURE(Fixture, "merge_should_be_fully_agnostic_of_hashmap_ordering") TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_number_or_string") { - ScopedFastFlag sff{"LuauOrPredicate", true}; - CheckResult result = check(R"( local function f(a: string | number | boolean) if type(a) ~= "number" and type(a) ~= "string" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index f1451a815..c3694be79 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1950,4 +1950,76 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type A = { x: number, y: number } +type B = { x: number, y: string } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type AS = { x: number, y: number } +type BS = { x: number, y: string } + +type A = { a: boolean, b: AS } +type B = { a: boolean, b: BS } + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property 'b' is not compatible. Type 'AS' could not be converted into 'BS' +caused by: + Property 'y' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") +{ + ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b1 = setmetatable({ x = 2, y = "hello" }, { __call = function(s) end }); +local c1: typeof(a1) = b1 + +local a2 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); +local b2 = setmetatable({ x = 2, y = 4 }, { __call = function(s, t) end }); +local c2: typeof(a2) = b2 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' +caused by: + Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' +caused by: + Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); + + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 453817574..30d9130a5 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -3926,8 +3926,6 @@ local b: number = 1 or a TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - CheckResult result = check(R"( --!strict local tbl = {} @@ -4493,10 +4491,6 @@ f(function(x) print(x) end) TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end return sum(2, 3, function(a, b) return a + b end) @@ -4525,10 +4519,6 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function g1(a: T, f: (T) -> T) return f(a) end local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end @@ -4579,10 +4569,6 @@ local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") { - ScopedFastFlag luauGenericFunctions("LuauGenericFunctions", true); - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); - ScopedFastFlag luauRankNTypes("LuauRankNTypes", true); - CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end @@ -4600,8 +4586,6 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local function f(): {string|number} return {1, "b", 3} @@ -4625,8 +4609,6 @@ end TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local function f() return {4, "b", 3} :: {string|number} @@ -4638,8 +4620,6 @@ end TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a: (number, number) -> number = function(a, b) return a - b end @@ -4655,8 +4635,6 @@ b, c = {2, "s"}, {"b", 4} TEST_CASE_FIXTURE(Fixture, "infer_assignment_value_types_mutable_lval") { - ScopedFastFlag luauInferReturnAssertAssign("LuauInferReturnAssertAssign", true); - CheckResult result = check(R"( local a = {} a.x = 2 @@ -4668,8 +4646,6 @@ a = setmetatable(a, { __call = function(x) end }) TEST_CASE_FIXTURE(Fixture, "refine_and_or") { - ScopedFastFlag sff{"LuauSlightlyMoreFlexibleBinaryPredicates", true}; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t and t.x or 5 @@ -4682,10 +4658,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") { - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - }; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t.x and t or 5 @@ -4698,10 +4670,6 @@ TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") { - ScopedFastFlag sffs[] = { - {"LuauSlightlyMoreFlexibleBinaryPredicates", true}, - }; - CheckResult result = check(R"( local t: {x: number?}? = {x = nil} local u = t and t.x == 5 or t.x == 31337 @@ -4714,7 +4682,7 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { - ScopedFastFlag luauFollowInTypeFunApply("LuauFollowInTypeFunApply", true); + ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; CheckResult result = check(R"( type A = { x: number } diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1192a8ac0..2d697fc97 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -178,9 +178,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_wor TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { - ScopedFastFlag sffs2{"LuauGenericFunctions", true}; - ScopedFastFlag sffs4{"LuauParseGenericFunctions", true}; - CheckResult result = check(R"( --!strict local function f(...: T): ...T @@ -199,8 +196,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unification") { - ScopedFastFlag sffs2("LuauGenericFunctions", true); - CheckResult result = check(R"( --!strict table.insert() diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 8dab2605b..c6de0abf5 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -296,7 +296,6 @@ end TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -361,7 +360,6 @@ local c: Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -395,7 +393,6 @@ local d: { a: typeof(c) } TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -434,7 +431,6 @@ type C = Import.Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -456,7 +452,6 @@ type Packed4 = (Packed3, T...) -> (Packed3, T...) TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -475,7 +470,6 @@ type E = X<(number, ...string)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -507,7 +501,6 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -534,7 +527,6 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -557,10 +549,8 @@ type D = Y TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - ScopedFastFlag luauInstantiatedTypeParamRecursion("LuauInstantiatedTypeParamRecursion", true); // For correct toString block CheckResult result = check(R"( type Y = { f: (T...) -> (U...) } @@ -577,7 +567,6 @@ local b: Y<(), ()> TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); @@ -599,7 +588,6 @@ type C = Y<(number), boolean> TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") { - ScopedFastFlag luauParseGenericFunctions("LuauParseGenericFunctions", true); ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 34c25a9fe..9f29b6428 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -400,8 +400,6 @@ local e = a.z TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") { - ScopedFastFlag luauSealedTableUnifyOptionalFix("LuauSealedTableUnifyOptionalFix", true); - CheckResult result = check(R"( local x: { x: number } = { x = 3 } type A = number? @@ -426,4 +424,43 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ +local b: { w: number } = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'X | Y | Z' could not be converted into '{| w: number |}' +caused by: + Not all union options are compatible. Table type 'X' not compatible with type '{| w: number |}' because the former is missing field 'w')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") +{ + ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } +type Y = { y: number } +type Z = { z: number } + +type XYZ = X | Y | Z + +local a: XYZ = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); +} + TEST_SUITE_END(); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 930c1a39b..91efa8188 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -11,8 +11,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauGenericFunctions); - TEST_SUITE_BEGIN("TypeVarTests"); TEST_CASE_FIXTURE(Fixture, "primitives_are_equal") diff --git a/tests/conformance/bitwise.lua b/tests/conformance/bitwise.lua index 6efa5960e..13be3f942 100644 --- a/tests/conformance/bitwise.lua +++ b/tests/conformance/bitwise.lua @@ -113,6 +113,20 @@ assert(bit32.replace(0, -1, 4) == 2^4) assert(bit32.replace(-1, 0, 31) == 2^31 - 1) assert(bit32.replace(-1, 0, 1, 2) == 2^32 - 7) +-- testing countlz/countrc +assert(bit32.countlz(0) == 32) +assert(bit32.countlz(42) == 26) +assert(bit32.countlz(0xffffffff) == 0) +assert(bit32.countlz(0x80000000) == 0) +assert(bit32.countlz(0x7fffffff) == 1) + +assert(bit32.countrz(0) == 32) +assert(bit32.countrz(1) == 0) +assert(bit32.countrz(42) == 1) +assert(bit32.countrz(0x80000000) == 31) +assert(bit32.countrz(0x40000000) == 30) +assert(bit32.countrz(0x7fffffff) == 0) + --[[ This test verifies a fix in luauF_replace() where if the 4th parameter was not a number, but the first three are numbers, it will @@ -136,5 +150,7 @@ assert(bit32.bxor("1", 3) == 2) assert(bit32.bxor(1, "3") == 2) assert(bit32.btest(1, "3") == true) assert(bit32.btest("1", 3) == true) +assert(bit32.countlz("42") == 26) +assert(bit32.countrz("42") == 1) return('OK') From 8fe0dc0b6d553dc053a43f5ed6607c9c4d1a26c0 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 18:23:34 -0800 Subject: [PATCH 04/32] Fix build --- VM/src/lbitlib.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index c72fe6748..907c43c42 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lcommon.h" #include "lnumutils.h" LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) From 863d3ff6ffa64398a6d13dc8c189bae30fefd557 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 19:42:50 -0800 Subject: [PATCH 05/32] Attempt to work around non-sensical error --- Analysis/src/TypeInfer.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8fad1af91..b27f3e170 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5030,7 +5030,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } - return std::nullopt; + std::optional res = std::nullopt; + return res; }; std::optional ty = resolveLValue(refis, scope, isaP.lvalue); From 3c3541aba84d9209b6098a5c6ae01727ab11ec32 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 11 Nov 2021 20:36:53 -0800 Subject: [PATCH 06/32] Add a comment --- Analysis/src/TypeInfer.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b27f3e170..a6696efde 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5030,6 +5030,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } + // local variable works around an odd gcc 9.3 warning: may be used uninitialized std::optional res = std::nullopt; return res; }; From 60e6e86adb5c7a687153989bc471fc05fc4637e8 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 18 Nov 2021 14:21:07 -0800 Subject: [PATCH 07/32] Sync to upstream/release/505 --- Analysis/include/Luau/Documentation.h | 11 +- Analysis/include/Luau/ToString.h | 11 +- Analysis/include/Luau/TypeInfer.h | 24 +- Analysis/include/Luau/TypeVar.h | 87 ++++- Analysis/include/Luau/Unifiable.h | 2 + Analysis/include/Luau/Unifier.h | 1 + Analysis/src/Autocomplete.cpp | 33 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 1 + Analysis/src/Error.cpp | 8 +- Analysis/src/Frontend.cpp | 4 +- Analysis/src/Module.cpp | 16 +- Analysis/src/ToString.cpp | 115 +++++- Analysis/src/Transpiler.cpp | 55 --- Analysis/src/TypeAttach.cpp | 16 + Analysis/src/TypeInfer.cpp | 374 ++++++++++++------- Analysis/src/TypePack.cpp | 1 - Analysis/src/TypeVar.cpp | 50 ++- Analysis/src/Unifier.cpp | 64 +++- Ast/include/Luau/Ast.h | 32 ++ Ast/include/Luau/Parser.h | 1 + Ast/include/Luau/StringUtils.h | 2 + Ast/src/Ast.cpp | 22 ++ Ast/src/Parser.cpp | 68 +++- Ast/src/StringUtils.cpp | 58 +++ CLI/Analyze.cpp | 28 +- CLI/FileUtils.cpp | 46 ++- CLI/FileUtils.h | 3 + CLI/Repl.cpp | 81 ++--- CMakeLists.txt | 17 +- Compiler/include/Luau/Compiler.h | 4 +- Compiler/include/luacode.h | 39 ++ Compiler/src/Compiler.cpp | 46 +-- Compiler/src/lcode.cpp | 29 ++ Makefile | 10 +- Sources.cmake | 3 + VM/include/lua.h | 18 +- VM/include/lualib.h | 6 +- VM/src/lapi.cpp | 10 +- VM/src/lbaselib.cpp | 8 +- VM/src/lbitlib.cpp | 1 + VM/src/lcorolib.cpp | 36 +- VM/src/ldebug.cpp | 10 +- VM/src/ldo.cpp | 6 +- VM/src/linit.cpp | 2 +- VM/src/lstate.cpp | 28 ++ VM/src/lstrlib.cpp | 2 +- VM/src/lutf8lib.cpp | 2 +- bench/bench.py | 32 +- fuzz/proto.cpp | 2 +- tests/Autocomplete.test.cpp | 16 +- tests/Compiler.test.cpp | 63 +--- tests/Conformance.test.cpp | 99 ++--- tests/IostreamOptional.h | 3 +- tests/Module.test.cpp | 6 +- tests/ToString.test.cpp | 111 +++++- tests/TypeInfer.aliases.test.cpp | 2 +- tests/TypeInfer.annotations.test.cpp | 4 +- tests/TypeInfer.generics.test.cpp | 6 +- tests/TypeInfer.refinements.test.cpp | 3 +- tests/TypeInfer.singletons.test.cpp | 377 ++++++++++++++++++++ tests/TypeInfer.test.cpp | 87 +++-- tests/TypeInfer.tryUnify.test.cpp | 30 +- tests/TypeInfer.unionTypes.test.cpp | 5 +- tests/conformance/coroutine.lua | 54 +++ tests/conformance/debugger.lua | 9 + 65 files changed, 1820 insertions(+), 580 deletions(-) create mode 100644 Compiler/include/luacode.h create mode 100644 Compiler/src/lcode.cpp create mode 100644 tests/TypeInfer.singletons.test.cpp diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 7b609b4fa..68ff3a7c2 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -12,10 +12,17 @@ namespace Luau struct FunctionDocumentation; struct TableDocumentation; struct OverloadedFunctionDocumentation; +struct BasicDocumentation; -using Documentation = Luau::Variant; +using Documentation = Luau::Variant; using DocumentationSymbol = std::string; +struct BasicDocumentation +{ + std::string documentation; + std::string learnMoreLink; +}; + struct FunctionParameterDocumentation { std::string name; @@ -29,6 +36,7 @@ struct FunctionDocumentation std::string documentation; std::vector parameters; std::vector returns; + std::string learnMoreLink; }; struct OverloadedFunctionDocumentation @@ -43,6 +51,7 @@ struct TableDocumentation { std::string documentation; Luau::DenseHashMap keys; + std::string learnMoreLink; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index e5683fc40..50379c1cd 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -23,10 +23,11 @@ struct ToStringNameMap struct ToStringOptions { - bool exhaustive = false; // If true, we produce complete output rather than comprehensible output - bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. - bool functionTypeArguments = false; // If true, output function type argument names when they are available - bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool exhaustive = false; // If true, we produce complete output rather than comprehensible output + bool useLineBreaks = false; // If true, we insert new lines to separate long results such as table entries/metatable. + bool functionTypeArguments = false; // If true, output function type argument names when they are available + bool hideTableKind = false; // If true, all tables will be surrounded with plain '{}' + bool hideNamedFunctionTypeParameters = false; // If true, type parameters of functions will be hidden at top-level. size_t maxTableLength = size_t(FInt::LuauTableTypeMaximumStringifierLength); // Only applied to TableTypeVars size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; @@ -64,6 +65,8 @@ inline std::string toString(TypePackId ty) std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {}); + // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression void dump(TypeId ty); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 306ac77d8..78d642c58 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -175,10 +175,10 @@ struct TypeChecker std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors); + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); - ExprResult reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + void reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, const std::vector& errors); @@ -282,6 +282,14 @@ struct TypeChecker // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); + // Produce an "emergency backup type" for recovery from type errors. + // This comes in two flavours, depening on whether or not we can make a good guess + // for an error recovery type. + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(const ScopePtr& scope); + TypePackId errorRecoveryTypePack(const ScopePtr& scope); + private: void prepareErrorsForDisplay(ErrorVec& errVec); void diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& data); @@ -297,6 +305,10 @@ struct TypeChecker TypeId freshType(const ScopePtr& scope); TypeId freshType(TypeLevel level); + // Produce a new singleton type var. + TypeId singletonType(bool value); + TypeId singletonType(std::string value); + // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); @@ -330,8 +342,8 @@ struct TypeChecker const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -347,7 +359,6 @@ struct TypeChecker void resolve(const OrPredicate& orP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const IsAPredicate& isaP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); - void DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); void resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense); bool isNonstrictMode() const; @@ -387,12 +398,9 @@ struct TypeChecker const TypeId booleanType; const TypeId threadType; const TypeId anyType; - - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; private: int checkRecursionCount = 0; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 6bd7932db..093ea4319 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -108,6 +108,79 @@ struct PrimitiveTypeVar } }; +// Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md +// Types for true and false +struct BoolSingleton +{ + bool value; + + bool operator==(const BoolSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const BoolSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// Types for "foo", "bar" etc. +struct StringSingleton +{ + std::string value; + + bool operator==(const StringSingleton& rhs) const + { + return value == rhs.value; + } + + bool operator!=(const StringSingleton& rhs) const + { + return !(*this == rhs); + } +}; + +// No type for float singletons, partly because === isn't any equalivalence on floats +// (NaN != NaN). + +using SingletonVariant = Luau::Variant; + +struct SingletonTypeVar +{ + explicit SingletonTypeVar(const SingletonVariant& variant) + : variant(variant) + { + } + + explicit SingletonTypeVar(SingletonVariant&& variant) + : variant(std::move(variant)) + { + } + + // Default operator== is C++20. + bool operator==(const SingletonTypeVar& rhs) const + { + return variant == rhs.variant; + } + + bool operator!=(const SingletonTypeVar& rhs) const + { + return !(*this == rhs); + } + + SingletonVariant variant; +}; + +template +const T* get(const SingletonTypeVar* stv) +{ + if (stv) + return get_if(&stv->variant); + else + return nullptr; +} + struct FunctionArgument { Name name; @@ -332,8 +405,8 @@ struct LazyTypeVar using ErrorTypeVar = Unifiable::Error; -using TypeVariant = Unifiable::Variant; +using TypeVariant = Unifiable::Variant; struct TypeVar final { @@ -410,6 +483,9 @@ bool isGeneric(const TypeId ty); // Checks if a type may be instantiated to one containing generic type binders bool maybeGeneric(const TypeId ty); +// Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton +bool maybeSingleton(TypeId ty); + struct SingletonTypes { const TypeId nilType; @@ -418,16 +494,19 @@ struct SingletonTypes const TypeId booleanType; const TypeId threadType; const TypeId anyType; - const TypeId errorType; const TypeId optionalNumberType; const TypePackId anyTypePack; - const TypePackId errorTypePack; SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; + TypeId errorRecoveryType(TypeId guess); + TypePackId errorRecoveryTypePack(TypePackId guess); + TypeId errorRecoveryType(); + TypePackId errorRecoveryTypePack(); + private: std::unique_ptr arena; TypeId makeStringMetatable(); diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index c2e07e466..b47610fca 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -105,6 +105,8 @@ struct Generic struct Error { + // This constructor has to be public, since it's used in TypeVar and TypePack, + // but shouldn't be called directly. Please use errorRecoveryType() instead. Error(); int index; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index be0aadd05..503034a1b 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -65,6 +65,7 @@ struct Unifier private: void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); void tryUnifyPrimitives(TypeId superTy, TypeId subTy); + void tryUnifySingletons(TypeId superTy, TypeId subTy); void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 1c94bb684..6fc0b3f88 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) +LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -198,11 +199,24 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - unifier.tryUnify(expectedType, actualType); + if (FFlag::LuauAutocompleteAvoidMutation) + { + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); + + auto errors = unifier.canUnify(expectedType, actualType); + return errors.empty(); + } + else + { + unifier.tryUnify(expectedType, actualType); - bool ok = unifier.errors.empty(); - unifier.log.rollback(); - return ok; + bool ok = unifier.errors.empty(); + unifier.log.rollback(); + return ok; + } }; auto expr = node->asExpr(); @@ -1496,11 +1510,9 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName if (!sourceModule) return {}; - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); - ModulePtr module = - (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) - : frontend.moduleResolver.getModule(moduleName)); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + ModulePtr module = (frontend.options.typecheckTwice ? frontend.moduleResolverForAutocomplete.getModule(moduleName) + : frontend.moduleResolver.getModule(moduleName)); if (!module) return {}; @@ -1527,8 +1539,7 @@ OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view sourceModule->mode = Mode::Strict; sourceModule->commentLocations = std::move(result.commentLocations); - TypeChecker& typeChecker = - (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); + TypeChecker& typeChecker = (frontend.options.typecheckTwice ? frontend.typeCheckerForAutocomplete : frontend.typeChecker); ModulePtr module = typeChecker.check(*sourceModule, Mode::Strict); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 96703ef16..9f5c82500 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -153,6 +153,7 @@ declare function gcinfo(): number wrap: ((A...) -> R...) -> any, yield: (A...) -> R..., isyieldable: () -> boolean, + close: (thread) -> (boolean, any?) } declare table: { diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 46ff2c72a..f80d50a7a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -180,13 +180,13 @@ struct ErrorConverter switch (e.context) { case CountMismatch::Return: - return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + - std::to_string(e.actual) + " " + actualVerb + " returned here"; + return "Expected to return " + std::to_string(e.expected) + " value" + expectedS + ", but " + std::to_string(e.actual) + " " + + actualVerb + " returned here"; case CountMismatch::Result: // It is alright if right hand side produces more values than the // left hand side accepts. In this context consider only the opposite case. - return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + - std::to_string(e.actual) + " are required here"; + return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + + " are required here"; case CountMismatch::Arg: if (FFlag::LuauTypeAliasPacks) return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 2f411274b..1e97705dc 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -22,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAGVARIABLE(LuauClearScopes, false) namespace Luau { @@ -458,8 +457,7 @@ CheckResult Frontend::check(const ModuleName& name) module->astTypes.clear(); module->astExpectedTypes.clear(); module->astOriginalCallTypes.clear(); - if (FFlag::LuauClearScopes) - module->scopes.resize(1); + module->scopes.resize(1); } if (mode != Mode::NoCheck) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 880ffd2e5..32a0646ae 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -161,6 +161,7 @@ struct TypeCloner void operator()(const Unifiable::Bound& t); void operator()(const Unifiable::Error& t); void operator()(const PrimitiveTypeVar& t); + void operator()(const SingletonTypeVar& t); void operator()(const FunctionTypeVar& t); void operator()(const TableTypeVar& t); void operator()(const MetatableTypeVar& t); @@ -199,7 +200,9 @@ struct TypePackCloner if (encounteredFreeType) *encounteredFreeType = true; - seenTypePacks[typePackId] = dest.addTypePack(TypePackVar{Unifiable::Error{}}); + TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId cloned = dest.addTypePack(*err); + seenTypePacks[typePackId] = cloned; } void operator()(const Unifiable::Generic& t) @@ -251,8 +254,9 @@ void TypeCloner::operator()(const Unifiable::Free& t) { if (encounteredFreeType) *encounteredFreeType = true; - - seenTypes[typeId] = dest.addType(ErrorTypeVar{}); + TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId cloned = dest.addType(*err); + seenTypes[typeId] = cloned; } void TypeCloner::operator()(const Unifiable::Generic& t) @@ -270,11 +274,17 @@ void TypeCloner::operator()(const Unifiable::Error& t) { defaultClone(t); } + void TypeCloner::operator()(const PrimitiveTypeVar& t) { defaultClone(t); } +void TypeCloner::operator()(const SingletonTypeVar& t) +{ + defaultClone(t); +} + void TypeCloner::operator()(const FunctionTypeVar& t) { TypeId result = dest.addType(FunctionTypeVar{TypeLevel{0, 0}, {}, {}, nullptr, nullptr, t.definition, t.hasSelf}); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 885fd489b..735bfa503 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -350,6 +350,23 @@ struct TypeVarStringifier } } + void operator()(TypeId, const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = Luau::get(&stv)) + state.emit(bs->value ? "true" : "false"); + else if (const StringSingleton* ss = Luau::get(&stv)) + { + state.emit("\""); + state.emit(escape(ss->value)); + state.emit("\""); + } + else + { + LUAU_ASSERT(!"Unknown singleton type"); + throw std::runtime_error("Unknown singleton type"); + } + } + void operator()(TypeId, const FunctionTypeVar& ftv) { if (state.hasSeen(&ftv)) @@ -359,6 +376,7 @@ struct TypeVarStringifier return; } + // We should not be respecting opts.hideNamedFunctionTypeParameters here. if (ftv.generics.size() > 0 || ftv.genericPacks.size() > 0) { state.emit("<"); @@ -514,7 +532,14 @@ struct TypeVarStringifier break; } - state.emit(name); + if (isIdentifier(name)) + state.emit(name); + else + { + state.emit("[\""); + state.emit(escape(name)); + state.emit("\"]"); + } state.emit(": "); stringify(prop.type); comma = true; @@ -1084,6 +1109,94 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + std::string s = prefix; + + auto toString_ = [&opts](TypeId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + auto toStringPack_ = [&opts](TypePackId ty) -> std::string { + ToStringResult res = toStringDetailed(ty, opts); + opts.nameMap = std::move(res.nameMap); + return res.name; + }; + + if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty())) + { + s += "<"; + + bool first = true; + for (TypeId g : ftv.generics) + { + if (!first) + s += ", "; + first = false; + s += toString_(g); + } + + for (TypePackId gp : ftv.genericPacks) + { + if (!first) + s += ", "; + first = false; + s += toStringPack_(gp); + } + + s += ">"; + } + + s += "("; + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + s += ", "; + first = false; + + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += toString_(*argPackIter); + + ++argPackIter; + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } + } + + if (argPackIter.tail()) + { + if (auto vtp = get(*argPackIter.tail())) + s += ", ...: " + toString_(vtp->ty); + else + s += ", ...: " + toStringPack_(*argPackIter.tail()); + } + + s += "): "; + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + if (retSize == 0 && !hasTail) + s += "()"; + else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail)) + s += toStringPack_(ftv.retType); + else + s += "(" + toStringPack_(ftv.retType) + ")"; + + return s; +} + void dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 7d880af49..6627fbe36 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -14,61 +14,6 @@ LUAU_FASTFLAG(LuauTypeAliasPacks) namespace { - -std::string escape(std::string_view s) -{ - std::string r; - r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting - - for (uint8_t c : s) - { - if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') - r += c; - else - { - r += '\\'; - - switch (c) - { - case '\a': - r += 'a'; - break; - case '\b': - r += 'b'; - break; - case '\f': - r += 'f'; - break; - case '\n': - r += 'n'; - break; - case '\r': - r += 'r'; - break; - case '\t': - r += 't'; - break; - case '\v': - r += 'v'; - break; - case '\'': - r += '\''; - break; - case '\"': - r += '\"'; - break; - case '\\': - r += '\\'; - break; - default: - Luau::formatAppend(r, "%03u", c); - } - } - } - - return r; -} - bool isIdentifierStartChar(char c) { return (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || c == '_'; diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 11aa7b394..af6d2543d 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -96,6 +96,22 @@ class TypeRehydrationVisitor return nullptr; } } + + AstType* operator()(const SingletonTypeVar& stv) + { + if (const BoolSingleton* bs = get(&stv)) + return allocator->alloc(Location(), bs->value); + else if (const StringSingleton* ss = get(&stv)) + { + AstArray value; + value.data = const_cast(ss->value.c_str()); + value.size = strlen(value.data); + return allocator->alloc(Location(), value); + } + else + return nullptr; + } + AstType* operator()(const AnyTypeVar&) { return allocator->alloc(Location(), std::nullopt, AstName("any")); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8fad1af91..b2ae94c72 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,9 @@ LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAG(LuauNewRequireTrace2) LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) +LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) namespace Luau { @@ -211,10 +214,8 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , booleanType(singletonTypes.booleanType) , threadType(singletonTypes.threadType) , anyType(singletonTypes.anyType) - , errorType(singletonTypes.errorType) , optionalNumberType(singletonTypes.optionalNumberType) , anyTypePack(singletonTypes.anyTypePack) - , errorTypePack(singletonTypes.errorTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -484,7 +485,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std TypeId type = bindings[name].type; if (get(follow(type))) { - *asMutable(type) = ErrorTypeVar{}; + *asMutable(type) = *errorRecoveryType(anyType); reportError(TypeError{typealias->location, OccursCheckFailed{}}); } } @@ -719,7 +720,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) else if (auto tail = valueIter.tail()) { if (get(*tail)) - right = errorType; + right = errorRecoveryType(scope); else if (auto vtp = get(*tail)) right = vtp->ty; else if (get(*tail)) @@ -961,7 +962,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorType, forin.location); + unify(var, errorRecoveryType(scope), forin.location); return check(loopScope, *forin.body); } @@ -979,7 +980,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) const FunctionTypeVar* iterFunc = get(iterTy); if (!iterFunc) { - TypeId varTy = get(iterTy) ? anyType : errorType; + TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) unify(var, varTy, forin.location); @@ -1152,9 +1153,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; else - bindingsMap[name] = TypeFun{binding->typeParams, errorType}; + bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; } else { @@ -1398,7 +1399,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) { reportErrorCodeTooComplex(expr.location); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult result; @@ -1407,12 +1408,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = checkExpr(scope, *a->expr); else if (expr.is()) result = {nilType}; - else if (expr.is()) - result = {booleanType}; + else if (const AstExprConstantBool* bexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(bexpr->value)}; + else + result = {booleanType}; + } + else if (const AstExprConstantString* sexpr = expr.as()) + { + if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; + else + result = {stringType}; + } else if (expr.is()) result = {numberType}; - else if (expr.is()) - result = {stringType}; else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) @@ -1485,7 +1496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprLo // TODO: tempting to ice here, but this breaks very often because our toposort doesn't enforce this constraint // ice("AstExprLocal exists but no binding definition for it?", expr.location); reportError(TypeError{expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGlobal& expr) @@ -1497,7 +1508,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprGl return {*ty, {TruthyPredicate{std::move(*lvalue), expr.location}}}; reportError(TypeError{expr.location, UnknownSymbol{expr.name.value, UnknownSymbol::Binding}}); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVarargs& expr) @@ -1509,7 +1520,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa std::vector types = flatten(varargPack).first; return {!types.empty() ? types[0] : nilType}; } - else if (auto ftp = get(varargPack)) + else if (get(varargPack)) { TypeId head = freshType(scope); TypePackId tail = freshTypePack(scope); @@ -1517,14 +1528,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprVa return {head}; } if (get(varargPack)) - return {errorType}; + return {errorRecoveryType(scope)}; else if (auto vtp = get(varargPack)) return {vtp->ty}; else if (get(varargPack)) { // TODO: Better error? reportError(expr.location, GenericError{"Trying to get a type from a variadic type parameter"}); - return {errorType}; + return {errorRecoveryType(scope)}; } else ice("Unknown TypePack type in checkExpr(AstExprVarargs)!"); @@ -1539,7 +1550,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { return {pack->head.empty() ? nilType : pack->head[0], std::move(result.predicates)}; } - else if (auto ftp = get(retPack)) + else if (get(retPack)) { TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); @@ -1547,7 +1558,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa return {head, std::move(result.predicates)}; } if (get(retPack)) - return {errorType, std::move(result.predicates)}; + return {errorRecoveryType(scope), std::move(result.predicates)}; else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) @@ -1572,7 +1583,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn if (std::optional ty = getIndexTypeFromType(scope, lhsType, name, expr.location, true)) return {*ty}; - return {errorType}; + return {errorRecoveryType(scope)}; } std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) @@ -1876,6 +1887,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa std::vector> fieldTypes(expr.items.size); const TableTypeVar* expectedTable = nullptr; + const UnionTypeVar* expectedUnion = nullptr; std::optional expectedIndexType; std::optional expectedIndexResultType; @@ -1894,6 +1906,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa } } } + else if (FFlag::LuauExpectedTypesOfProperties) + if (const UnionTypeVar* utv = get(follow(*expectedType))) + expectedUnion = utv; } for (size_t i = 0; i < expr.items.size; ++i) @@ -1916,6 +1931,18 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTa if (auto prop = expectedTable->props.find(key->value.data); prop != expectedTable->props.end()) expectedResultType = prop->second.type; } + else if (FFlag::LuauExpectedTypesOfProperties && expectedUnion) + { + std::vector expectedResultTypes; + for (TypeId expectedOption : expectedUnion) + if (const TableTypeVar* ttv = get(follow(expectedOption))) + if (auto prop = ttv->props.find(key->value.data); prop != ttv->props.end()) + expectedResultTypes.push_back(prop->second.type); + if (expectedResultTypes.size() == 1) + expectedResultType = expectedResultTypes[0]; + else if (expectedResultTypes.size() > 1) + expectedResultType = addType(UnionTypeVar{expectedResultTypes}); + } } else { @@ -1958,21 +1985,22 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn { TypeId actualFunctionType = instantiate(scope, *fnt, expr.location); TypePackId arguments = addTypePack({operandType}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) - return {errorType}; + retType = errorRecoveryType(retType); - return {first(retType).value_or(nilType)}; + return {retType}; } reportError(expr.location, GenericError{format("Unary operator '%s' not supported by type '%s'", toString(expr.op).c_str(), toString(operandType).c_str())}); - return {errorType}; + return {errorRecoveryType(scope)}; } reportErrors(tryUnify(numberType, operandType, expr.location)); @@ -1984,7 +2012,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn operandType = stripFromNilAndReport(operandType, expr.location); if (get(operandType)) - return {errorType}; + return {errorRecoveryType(scope)}; if (get(operandType)) return {numberType}; // Not strictly correct: metatables permit overriding this @@ -2044,7 +2072,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b if (unify(a, b, location)) return a; - return errorType; + return errorRecoveryType(anyType); } if (*a == *b) @@ -2166,11 +2194,13 @@ TypeId TypeChecker::checkRelationalOperation( std::optional leftMetatable = isString(lhsType) ? std::nullopt : getMetatable(follow(lhsType)); std::optional rightMetatable = isString(rhsType) ? std::nullopt : getMetatable(follow(rhsType)); + // TODO: this check seems odd, the second part is redundant + // is it meant to be if (leftMetatable && rightMetatable && leftMetatable != rightMetatable) if (bool(leftMetatable) != bool(rightMetatable) && leftMetatable != rightMetatable) { reportError(expr.location, GenericError{format("Types %s and %s cannot be compared with %s because they do not have the same metatable", toString(lhsType).c_str(), toString(rhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } if (leftMetatable) @@ -2188,7 +2218,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } } @@ -2206,7 +2236,7 @@ TypeId TypeChecker::checkRelationalOperation( { reportError( expr.location, GenericError{format("Table %s does not offer metamethod %s", toString(lhsType).c_str(), metamethodName.c_str())}); - return errorType; + return errorRecoveryType(booleanType); } } @@ -2214,14 +2244,14 @@ TypeId TypeChecker::checkRelationalOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Comparison}); - return errorType; + return errorRecoveryType(booleanType); } if (needsMetamethod) { reportError(expr.location, GenericError{format("Type %s cannot be compared with %s because it has no metatable", toString(lhsType).c_str(), toString(expr.op).c_str())}); - return errorType; + return errorRecoveryType(booleanType); } return booleanType; @@ -2266,7 +2296,8 @@ TypeId TypeChecker::checkBinaryOperation( { auto name = getIdentifierOfBaseVar(expr.left); reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - return errorType; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -2296,18 +2327,33 @@ TypeId TypeChecker::checkBinaryOperation( auto checkMetatableCall = [this, &scope, &expr](TypeId fnt, TypeId lhst, TypeId rhst) -> TypeId { TypeId actualFunctionType = instantiate(scope, fnt, expr.location); TypePackId arguments = addTypePack({lhst, rhst}); - TypePackId retType = freshTypePack(scope); - TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retType)); + TypePackId retTypePack = freshTypePack(scope); + TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); + bool hasErrors = !state.errors.empty(); - if (!state.errors.empty()) - return errorType; + if (FFlag::LuauErrorRecoveryType && hasErrors) + { + // If there are unification errors, the return type may still be unknown + // so we loosen the argument types to see if that helps. + TypePackId fallbackArguments = freshTypePack(scope); + TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); + state.log.rollback(); + state.errors.clear(); + state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); + if (!state.errors.empty()) + state.log.rollback(); + } - return first(retType).value_or(nilType); + TypeId retType = first(retTypePack).value_or(nilType); + if (hasErrors) + retType = errorRecoveryType(retType); + + return retType; }; std::string op = opToMetaTableEntry(expr.op); @@ -2321,7 +2367,8 @@ TypeId TypeChecker::checkBinaryOperation( reportError(expr.location, GenericError{format("Binary operator '%s' not supported by types '%s' and '%s'", toString(expr.op).c_str(), toString(lhsType).c_str(), toString(rhsType).c_str())}); - return errorType; + + return errorRecoveryType(scope); } switch (expr.op) @@ -2414,11 +2461,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy ExprResult result = checkExpr(scope, *expr.expr, annotationType); ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); if (!errorVec.empty()) - { - reportErrors(errorVec); - return {errorType, std::move(result.predicates)}; - } + annotationType = errorRecoveryType(annotationType); return {annotationType, std::move(result.predicates)}; } @@ -2434,7 +2479,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr // any type errors that may arise from it are going to be useless. currentModule->errors.resize(oldSize); - return {errorType}; + return {errorRecoveryType(scope)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) @@ -2476,7 +2521,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return std::pair(errorType, nullptr); + return {errorRecoveryType(scope), nullptr}; } else ice("Unexpected AST node in checkLValue", expr.location); @@ -2488,7 +2533,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope return {*ty, nullptr}; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorType, nullptr}; + return {errorRecoveryType(scope), nullptr}; } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) @@ -2545,24 +2590,25 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { Unifier state = mkUnifier(expr.location); state.tryUnify(indexer->indexType, stringType); + TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { state.log.rollback(); reportError(expr.location, UnknownProperty{lhs, name}); - return std::pair(errorType, nullptr); + retType = errorRecoveryType(retType); } - return std::pair(indexer->indexResultType, nullptr); + return std::pair(retType, nullptr); } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2571,7 +2617,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); @@ -2585,12 +2631,12 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) @@ -2615,7 +2661,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } return std::pair(prop->type, nullptr); } @@ -2626,7 +2672,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorType, nullptr); + return std::pair(errorRecoveryType(scope), nullptr); } if (value) @@ -2678,7 +2724,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) if (isNonstrictMode()) return globalScope->bindings[name].typeId; - return errorType; + return errorRecoveryType(scope); } else { @@ -2705,20 +2751,21 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) TableTypeVar* ttv = getMutableTableType(lhsType); if (!ttv) { - if (!isTableIntersection(lhsType)) + if (!FFlag::LuauErrorRecoveryType && !isTableIntersection(lhsType)) + // This error now gets reported when we check the function body. reportError(TypeError{funName.location, OnlyTablesCanHaveMethods{lhsType}}); - return errorType; + return errorRecoveryType(scope); } // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorType; + return errorRecoveryType(scope); Name name = indexName->index.value; if (ttv->props.count(name)) - return errorType; + return errorRecoveryType(scope); Property& property = ttv->props[name]; @@ -2728,9 +2775,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) return property.type; } else if (funName.is()) - { - return errorType; - } + return errorRecoveryType(scope); else { ice("Unexpected AST node type", funName.location); @@ -2991,7 +3036,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A else if (expr.is()) { if (!scope->varargPack) - return {addTypePack({addType(ErrorTypeVar())})}; + return {errorRecoveryTypePack(scope)}; return {*scope->varargPack}; } @@ -3095,10 +3140,9 @@ void TypeChecker::checkArgumentList( if (get(tail)) { // Unify remaining parameters so we don't leave any free-types hanging around. - TypeId argTy = errorType; while (paramIter != endIter) { - state.tryUnify(*paramIter, argTy); + state.tryUnify(*paramIter, errorRecoveryType(anyType)); ++paramIter; } return; @@ -3157,7 +3201,7 @@ void TypeChecker::checkArgumentList( { while (argIter != endIter) { - unify(*argIter, errorType, state.location); + unify(*argIter, errorRecoveryType(scope), state.location); ++argIter; } // For this case, we want the error span to cover every errant extra parameter @@ -3246,7 +3290,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A // For each overload // Compare parameter and argument types // Report any errors (also speculate dot vs colon warnings!) - // If there are no errors, return the resulting return type + // Return the resulting return type (even if there are errors) + // If there are no matching overloads, unify with (a...) -> (b...) and return b... TypeId selfType = nullptr; TypeId functionType = nullptr; @@ -3268,8 +3313,8 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } else { - functionType = errorType; - actualFunctionType = errorType; + functionType = errorRecoveryType(scope); + actualFunctionType = functionType; } } else @@ -3296,7 +3341,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A TypePackId argPack = argListResult.type; if (get(argPack)) - return ExprResult{errorTypePack}; + return {errorRecoveryTypePack(scope)}; TypePack* args = getMutable(argPack); LUAU_ASSERT(args != nullptr); @@ -3314,19 +3359,34 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector errors; // errors encountered for each overload std::vector overloadsThatMatchArgCount; + std::vector overloadsThatDont; for (TypeId fn : overloads) { fn = follow(fn); - if (auto ret = checkCallOverload(scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, errors)) + if (auto ret = checkCallOverload( + scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } if (handleSelfCallMismatch(scope, expr, args, argLocations, errors)) return {retPack}; - return reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + reportOverloadResolutionError(scope, expr, retPack, argPack, argLocations, overloads, overloadsThatMatchArgCount, errors); + + if (FFlag::LuauErrorRecoveryType) + { + const FunctionTypeVar* overload = nullptr; + if (!overloadsThatMatchArgCount.empty()) + overload = get(overloadsThatMatchArgCount[0]); + if (!overload && !overloadsThatDont.empty()) + overload = get(overloadsThatDont[0]); + if (overload) + return {errorRecoveryTypePack(overload->retType)}; + } + + return {errorRecoveryTypePack(retPack)}; } std::vector> TypeChecker::getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall) @@ -3382,7 +3442,7 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, - std::vector& overloadsThatMatchArgCount, std::vector& errors) + std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { fn = stripFromNilAndReport(fn, expr.func->location); @@ -3394,7 +3454,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - return {{addTypePack(TypePackVar{Unifiable::Error{}})}}; + return {{errorRecoveryTypePack(scope)}}; } if (get(fn)) @@ -3427,14 +3487,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload( - scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, overloadsThatMatchArgCount, errors); + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + overloadsThatMatchArgCount, overloadsThatDont, errors); } } reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorTypePack, expr.func->location); - return {{errorTypePack}}; + unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + return {{errorRecoveryTypePack(retPack)}}; } // When this function type has magic functions and did return something, we select that overload instead. @@ -3476,6 +3536,8 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!argMismatch) overloadsThatMatchArgCount.push_back(fn); + else if (FFlag::LuauErrorRecoveryType) + overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); state.log.rollback(); @@ -3586,14 +3648,14 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal return false; } -ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, - TypePackId argPack, const std::vector& argLocations, const std::vector& overloads, - const std::vector& overloadsThatMatchArgCount, const std::vector& errors) +void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const AstExprCall& expr, TypePackId retPack, TypePackId argPack, + const std::vector& argLocations, const std::vector& overloads, const std::vector& overloadsThatMatchArgCount, + const std::vector& errors) { if (overloads.size() == 1) { reportErrors(std::get<0>(errors.front())); - return {errorTypePack}; + return; } std::vector overloadTypes = overloadsThatMatchArgCount; @@ -3622,7 +3684,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr // If only one overload matched, we don't need this error because we provided the previous errors. if (overloadsThatMatchArgCount.size() == 1) - return {errorTypePack}; + return; } std::string s; @@ -3655,7 +3717,7 @@ ExprResult TypeChecker::reportOverloadResolutionError(const ScopePtr reportError(expr.func->location, ExtraInformation{"Other overloads are also not viable: " + s}); // No viable overload - return {errorTypePack}; + return; } ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const Location& location, const AstArray& exprs, @@ -3740,7 +3802,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); - return errorType; + return errorRecoveryType(anyType); } return anyType; @@ -3758,14 +3820,14 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module reportError(TypeError{location, UnknownRequire{reportedModulePath}}); } - return errorType; + return errorRecoveryType(scope); } if (module->type != SourceCode::Module) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module is not a ModuleScript. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } std::optional moduleType = first(module->getModuleScope()->returnType); @@ -3773,7 +3835,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(location, IllegalRequire{humanReadableName, "Module does not return exactly 1 value. It cannot be required."}); - return errorType; + return errorRecoveryType(scope); } SeenTypes seenTypes; @@ -4078,7 +4140,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!qty.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (ty == *qty) @@ -4101,7 +4163,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } } @@ -4116,7 +4178,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) else { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(anyType); } } @@ -4131,7 +4193,7 @@ TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location lo else { reportError(location, UnificationTooComplex{}); - return errorTypePack; + return errorRecoveryTypePack(anyTypePack); } } @@ -4279,6 +4341,38 @@ TypeId TypeChecker::freshType(TypeLevel level) return currentModule->internalTypes.addType(TypeVar(FreeTypeVar(level))); } +TypeId TypeChecker::singletonType(bool value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); +} + +TypeId TypeChecker::singletonType(std::string value) +{ + // TODO: cache singleton types + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(StringSingleton{std::move(value)}))); +} + +TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryType(); +} + +TypeId TypeChecker::errorRecoveryType(TypeId guess) +{ + return singletonTypes.errorRecoveryType(guess); +} + +TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) +{ + return singletonTypes.errorRecoveryTypePack(); +} + +TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) +{ + return singletonTypes.errorRecoveryTypePack(guess); +} + std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4350,7 +4444,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size != 1 || !lit->parameters.data[0].type) { reportError(TypeError{annotation.location, GenericError{"_luau_print requires one generic parameter"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(anyType); } ToStringOptions opts; @@ -4368,7 +4462,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { if (lit->name == Parser::errorName) - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); std::string typeName; if (lit->hasPrefix) @@ -4380,7 +4474,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else reportError(TypeError{annotation.location, UnknownSymbol{typeName, UnknownSymbol::Type}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) @@ -4390,14 +4484,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) { reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } - else if (FFlag::LuauTypeAliasPacks) + + if (FFlag::LuauTypeAliasPacks) { if (!lit->hasParameterList && !tf->typePackParams.empty()) { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - return addType(ErrorTypeVar{}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); } std::vector typeParams; @@ -4445,7 +4542,17 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation { reportError( TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); - return addType(ErrorTypeVar{}); + + if (FFlag::LuauErrorRecoveryType) + { + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) + typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); + } + else + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) @@ -4464,6 +4571,14 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation for (const auto& param : lit->parameters) typeParams.push_back(resolveType(scope, *param.type)); + if (FFlag::LuauErrorRecoveryType) + { + // If there aren't enough type parameters, pad them out with error recovery types + // (we've already reported the error) + while (typeParams.size() < lit->parameters.size) + typeParams.push_back(errorRecoveryType(scope)); + } + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) { // If the generic parameters and the type arguments are the same, we are about to @@ -4483,8 +4598,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation props[prop.name.value] = {resolveType(scope, *prop.type)}; if (const auto& indexer = table->indexer) - tableIndexer = TableIndexer( - resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); + tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); return addType(TableTypeVar{ props, tableIndexer, scope->level, @@ -4536,14 +4650,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return addType(IntersectionTypeVar{types}); } - else if (annotation.is()) + else if (const auto& tsb = annotation.as()) + { + return singletonType(tsb->value); + } + else if (const auto& tss = annotation.as()) { - return addType(ErrorTypeVar{}); + return singletonType(std::string(tss->value.data, tss->value.size)); } + else if (annotation.is()) + return errorRecoveryType(scope); else { reportError(TypeError{annotation.location, GenericError{"Unknown type annotation?"}}); - return addType(ErrorTypeVar{}); + return errorRecoveryType(scope); } } @@ -4584,7 +4704,7 @@ TypePackId TypeChecker::resolveTypePack(const ScopePtr& scope, const AstTypePack else reportError(TypeError{generic->location, UnknownSymbol{genericName, UnknownSymbol::Type}}); - return addTypePack(TypePackVar{Unifiable::Error{}}); + return errorRecoveryTypePack(scope); } return *genericTy; @@ -4706,12 +4826,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (!maybeInstantiated.has_value()) { reportError(location, UnificationTooComplex{}); - return errorType; + return errorRecoveryType(scope); } if (FFlag::LuauRecursiveTypeParameterRestriction && applyTypeFunction.encounteredForwardedType) { reportError(TypeError{location, GenericError{"Recursive type being used with different parameters"}}); - return errorType; + return errorRecoveryType(scope); } TypeId instantiated = *maybeInstantiated; @@ -4773,8 +4893,8 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes( - const ScopePtr& scope, std::optional levelOpt, const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, + const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); @@ -5030,7 +5150,9 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return isaP.ty; } - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; }; std::optional ty = resolveLValue(refis, scope, isaP.lvalue); @@ -5041,7 +5163,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement addRefinement(refis, isaP.lvalue, *result); else { - addRefinement(refis, isaP.lvalue, errorType); + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); } } @@ -5105,7 +5227,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec addRefinement(refis, typeguardP.lvalue, *result); else { - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); if (sense) errVec.push_back( TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); @@ -5116,7 +5238,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec auto fail = [&](const TypeErrorData& err) { errVec.push_back(TypeError{typeguardP.location, err}); - addRefinement(refis, typeguardP.lvalue, errorType); + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; if (!typeguardP.isTypeof) @@ -5137,28 +5259,6 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return resolve(IsAPredicate{std::move(typeguardP.lvalue), typeguardP.location, type}, errVec, refis, scope, sense); } -void TypeChecker::DEPRECATED_resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) -{ - if (!sense) - return; - - static std::vector primitives{ - "string", "number", "boolean", "nil", "thread", - "table", // no op. Requires special handling. - "function", // no op. Requires special handling. - "userdata", // no op. Requires special handling. - }; - - if (auto typeFun = globalScope->lookupType(typeguardP.kind); - typeFun && typeFun->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || typeFun->typePackParams.empty())) - { - if (auto it = std::find(primitives.begin(), primitives.end(), typeguardP.kind); it != primitives.end()) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - else if (typeguardP.isTypeof) - addRefinement(refis, typeguardP.lvalue, typeFun->type); - } -} - void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 228b19267..d3221c732 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -286,5 +286,4 @@ TypePack* asMutable(const TypePack* tp) { return const_cast(tp); } - } // namespace Luau diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index cd447ca23..924bf082a 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAG(LuauErrorRecoveryType) namespace Luau { @@ -293,7 +294,7 @@ bool isGeneric(TypeId ty) bool maybeGeneric(TypeId ty) { ty = follow(ty); - if (auto ftv = get(ty)) + if (get(ty)) return true; else if (auto ttv = get(ty)) { @@ -305,6 +306,18 @@ bool maybeGeneric(TypeId ty) return isGeneric(ty); } +bool maybeSingleton(TypeId ty) +{ + ty = follow(ty); + if (get(ty)) + return true; + if (const UnionTypeVar* utv = get(ty)) + for (TypeId option : utv) + if (get(follow(option))) + return true; + return false; +} + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) : argTypes(argTypes) , retType(retType) @@ -562,10 +575,8 @@ SingletonTypes::SingletonTypes() , booleanType(&booleanType_) , threadType(&threadType_) , anyType(&anyType_) - , errorType(&errorType_) , optionalNumberType(&optionalNumberType_) , anyTypePack(&anyTypePack_) - , errorTypePack(&errorTypePack_) , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); @@ -634,6 +645,32 @@ TypeId SingletonTypes::makeStringMetatable() return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } +TypeId SingletonTypes::errorRecoveryType() +{ + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack() +{ + return &errorTypePack_; +} + +TypeId SingletonTypes::errorRecoveryType(TypeId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorType_; +} + +TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) +{ + if (FFlag::LuauErrorRecoveryType) + return guess; + else + return &errorTypePack_; +} + SingletonTypes singletonTypes; void persist(TypeId ty) @@ -1141,6 +1178,11 @@ struct QVarFinder return false; } + bool operator()(const SingletonTypeVar&) const + { + return false; + } + bool operator()(const FunctionTypeVar& ftv) const { if (hasGeneric(ftv.argTypes)) @@ -1412,7 +1454,7 @@ static std::vector parseFormatString(TypeChecker& typechecker, const cha else if (strchr(options, data[i])) result.push_back(typechecker.numberType); else - result.push_back(typechecker.errorType); + result.push_back(typechecker.errorRecoveryType(typechecker.anyType)); } } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 82f621b66..e1a52be4e 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,7 +22,9 @@ LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) LUAU_FASTFLAG(LuauShareTxnSeen); LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) +LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) +LUAU_FASTFLAG(LuauErrorRecoveryType); namespace Luau { @@ -211,6 +213,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(subTy, superTy); + // The occurrence check might have caused superTy no longer to be a free type if (!get(subTy)) { log(subTy); @@ -221,10 +224,20 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (l && r) { - log(superTy); + if (!FFlag::LuauErrorRecoveryType) + log(superTy); occursCheck(superTy, subTy); r->level = min(r->level, l->level); - *asMutable(superTy) = BoundTypeVar(subTy); + + // The occurrence check might have caused superTy no longer to be a free type + if (!FFlag::LuauErrorRecoveryType) + *asMutable(superTy) = BoundTypeVar(subTy); + else if (!get(superTy)) + { + log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } + return; } else if (l) @@ -240,6 +253,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } + // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { if (auto rightLevel = getMutableLevel(subTy)) @@ -251,6 +265,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); } + return; } else if (r) @@ -512,6 +527,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (get(superTy) && get(subTy)) tryUnifyPrimitives(superTy, subTy); + else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) + tryUnifySingletons(superTy, subTy); + else if (get(superTy) && get(subTy)) tryUnifyFunctions(superTy, subTy, isFunctionCall); @@ -723,17 +741,18 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal { occursCheck(superTp, subTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(superTp)) { log(superTp); *asMutable(superTp) = Unifiable::Bound(subTp); } } - else if (get(subTp)) { occursCheck(subTp, superTp); + // The occurrence check might have caused superTp no longer to be a free type if (!get(subTp)) { log(subTp); @@ -874,13 +893,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorType, *superIter); + tryUnify_(singletonTypes.errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorType, *subIter); + tryUnify_(singletonTypes.errorRecoveryType(), *subIter); subIter.advance(); } @@ -906,6 +925,27 @@ void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } +void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +{ + const PrimitiveTypeVar* lp = get(superTy); + const SingletonTypeVar* ls = get(superTy); + const SingletonTypeVar* rs = get(subTy); + + if ((!lp && !ls) || !rs) + ice("passed non singleton/primitive types to unifySingletons"); + + if (ls && *ls == *rs) + return; + + if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + return; + + if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + return; + + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); +} + void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) { FunctionTypeVar* lf = getMutable(superTy); @@ -1023,7 +1063,8 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && + lt->state != TableState::Free) { for (const auto& [propName, subProp] : rt->props) { @@ -1038,7 +1079,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) return; } } - + // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. @@ -1634,9 +1675,8 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorType); + tryUnify_(prop.type, singletonTypes.errorRecoveryType()); } else { @@ -1952,7 +1992,7 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorType; + const TypeId anyTy = singletonTypes.errorRecoveryType(); if (FFlag::LuauTypecheckOpts) { @@ -2046,7 +2086,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryType(); return; } @@ -2134,7 +2174,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = ErrorTypeVar{}; + *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); return; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index a2189f7b7..5b4bfa033 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -255,6 +255,14 @@ class AstVisitor { return visit((class AstType*)node); } + virtual bool visit(class AstTypeSingletonBool* node) + { + return visit((class AstType*)node); + } + virtual bool visit(class AstTypeSingletonString* node) + { + return visit((class AstType*)node); + } virtual bool visit(class AstTypeError* node) { return visit((class AstType*)node); @@ -1158,6 +1166,30 @@ class AstTypeError : public AstType unsigned messageIndex; }; +class AstTypeSingletonBool : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonBool) + + AstTypeSingletonBool(const Location& location, bool value); + + void visit(AstVisitor* visitor) override; + + bool value; +}; + +class AstTypeSingletonString : public AstType +{ +public: + LUAU_RTTI(AstTypeSingletonString) + + AstTypeSingletonString(const Location& location, const AstArray& value); + + void visit(AstVisitor* visitor) override; + + const AstArray value; +}; + class AstTypePack : public AstNode { public: diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 39c7d9251..87ebc48b5 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -286,6 +286,7 @@ class Parser // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); + std::optional> parseCharArray(); AstExpr* parseString(); AstLocal* pushLocal(const Binding& binding); diff --git a/Ast/include/Luau/StringUtils.h b/Ast/include/Luau/StringUtils.h index 4f7673fab..6ecf06062 100644 --- a/Ast/include/Luau/StringUtils.h +++ b/Ast/include/Luau/StringUtils.h @@ -34,4 +34,6 @@ bool equalsLower(std::string_view lhs, std::string_view rhs); size_t hashRange(const char* data, size_t size); +std::string escape(std::string_view s); +bool isIdentifier(std::string_view s); } // namespace Luau diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index b1209faa1..e709894d9 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -841,6 +841,28 @@ void AstTypeIntersection::visit(AstVisitor* visitor) } } +AstTypeSingletonBool::AstTypeSingletonBool(const Location& location, bool value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonBool::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + +AstTypeSingletonString::AstTypeSingletonString(const Location& location, const AstArray& value) + : AstType(ClassIndex(), location) + , value(value) +{ +} + +void AstTypeSingletonString::visit(AstVisitor* visitor) +{ + visitor->visit(this); +} + AstTypeError::AstTypeError(const Location& location, const AstArray& types, bool isMissing, unsigned messageIndex) : AstType(ClassIndex(), location) , types(types) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index a1bad65ef..bc63e37dc 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) +LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau @@ -1278,7 +1279,27 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { - if (lexer.current().type == '[') + if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' && + (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + { + const Lexeme begin = lexer.current(); + nextLexeme(); // [ + std::optional> chars = parseCharArray(); + + expectMatchAndConsume(']', begin); + expectAndConsume(':', "table field"); + + AstType* type = parseTypeAnnotation(); + + // TODO: since AstName conains a char*, it can't contain null + bool containsNull = chars && (strnlen(chars->data, chars->size) < chars->size); + + if (chars && !containsNull) + props.push_back({AstName(chars->data), begin.location, type}); + else + report(begin.location, "String literal contains malformed escape sequence"); + } + else if (lexer.current().type == '[') { if (indexer) { @@ -1528,6 +1549,32 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue) + { + nextLexeme(); + return {allocator.alloc(begin, true)}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse) + { + nextLexeme(); + return {allocator.alloc(begin, false)}; + } + else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)) + { + if (std::optional> value = parseCharArray()) + { + AstArray svalue = *value; + return {allocator.alloc(begin, svalue)}; + } + else + return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; + } + else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString) + { + Location location = lexer.current().location; + nextLexeme(); + return {reportTypeAnnotationError(location, {}, /*isMissing*/ false, "Malformed string")}; + } else if (lexer.current().type == Lexeme::Name) { std::optional prefix; @@ -2416,7 +2463,7 @@ AstArray Parser::parseTypeParams() return copy(parameters); } -AstExpr* Parser::parseString() +std::optional> Parser::parseCharArray() { LUAU_ASSERT(lexer.current().type == Lexeme::QuotedString || lexer.current().type == Lexeme::RawString); @@ -2426,11 +2473,8 @@ AstExpr* Parser::parseString() { if (!Lexer::fixupQuotedString(scratchData)) { - Location location = lexer.current().location; - nextLexeme(); - - return reportExprError(location, {}, "String literal contains malformed escape sequence"); + return std::nullopt; } } else @@ -2438,12 +2482,18 @@ AstExpr* Parser::parseString() Lexer::fixupMultilineString(scratchData); } - Location start = lexer.current().location; AstArray value = copy(scratchData); - nextLexeme(); + return value; +} - return allocator.alloc(start, value); +AstExpr* Parser::parseString() +{ + Location location = lexer.current().location; + if (std::optional> value = parseCharArray()) + return allocator.alloc(location, *value); + else + return reportExprError(location, {}, "String literal contains malformed escape sequence"); } AstLocal* Parser::pushLocal(const Binding& binding) diff --git a/Ast/src/StringUtils.cpp b/Ast/src/StringUtils.cpp index 24b2283a0..9c7fed316 100644 --- a/Ast/src/StringUtils.cpp +++ b/Ast/src/StringUtils.cpp @@ -225,4 +225,62 @@ size_t hashRange(const char* data, size_t size) return hash; } +bool isIdentifier(std::string_view s) +{ + return (s.find_first_not_of("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ01234567890_") == std::string::npos); +} + +std::string escape(std::string_view s) +{ + std::string r; + r.reserve(s.size() + 50); // arbitrary number to guess how many characters we'll be inserting + + for (uint8_t c : s) + { + if (c >= ' ' && c != '\\' && c != '\'' && c != '\"') + r += c; + else + { + r += '\\'; + + switch (c) + { + case '\a': + r += 'a'; + break; + case '\b': + r += 'b'; + break; + case '\f': + r += 'f'; + break; + case '\n': + r += 'n'; + break; + case '\r': + r += 'r'; + break; + case '\t': + r += 't'; + break; + case '\v': + r += 'v'; + break; + case '\'': + r += '\''; + break; + case '\"': + r += '\"'; + break; + case '\\': + r += '\\'; + break; + default: + Luau::formatAppend(r, "%03u", c); + } + } + } + + return r; +} } // namespace Luau diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 9ab10aaf0..ebdd78966 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -236,32 +236,12 @@ int main(int argc, char** argv) Luau::registerBuiltinTypes(frontend.typeChecker); Luau::freeze(frontend.typeChecker.globalTypes); - int failed = 0; + std::vector files = getSourceFiles(argc, argv); - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; + int failed = 0; - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - // Look for .luau first and if absent, fall back to .lua - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - { - failed += !analyzeFile(frontend, name.c_str(), format, annotate); - } - }); - } - else - { - failed += !analyzeFile(frontend, argv[i], format, annotate); - } - } + for (const std::string& path : files) + failed += !analyzeFile(frontend, path.c_str(), format, annotate); if (!configResolver.configErrors.empty()) { diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index 0702b74f1..b3c9557bb 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -142,6 +142,7 @@ static bool traverseDirectoryRec(const std::string& path, const std::function getParentPath(const std::string& path) return ""; } + +static std::string getExtension(const std::string& path) +{ + std::string::size_type dot = path.find_last_of(".\\/"); + + if (dot == std::string::npos || path[dot] != '.') + return ""; + + return path.substr(dot); +} + +std::vector getSourceFiles(int argc, char** argv) +{ + std::vector files; + + for (int i = 1; i < argc; ++i) + { + if (argv[i][0] == '-') + continue; + + if (isDirectory(argv[i])) + { + traverseDirectory(argv[i], [&](const std::string& name) { + std::string ext = getExtension(name); + + if (ext == ".lua" || ext == ".luau") + files.push_back(name); + }); + } + else + { + files.push_back(argv[i]); + } + } + + return files; +} diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index f7fbe8afc..da11f512d 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -4,6 +4,7 @@ #include #include #include +#include std::optional readFile(const std::string& name); @@ -12,3 +13,5 @@ bool traverseDirectory(const std::string& path, const std::function getParentPath(const std::string& path); + +std::vector getSourceFiles(int argc, char** argv); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5c904cca4..b29cd6f9c 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -20,7 +20,7 @@ enum class CompileFormat { - Default, + Text, Binary }; @@ -33,7 +33,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; lua_pushnil(L); @@ -80,7 +80,7 @@ static int lua_require(lua_State* L) // now we can compile & run module on the new thread std::string bytecode = Luau::compile(*source); - if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { int status = lua_resume(ML, L, 0); @@ -151,7 +151,7 @@ static std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source); - if (luau_load(L, "=stdin", bytecode.data(), bytecode.size()) != 0) + if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { size_t len; const char* msg = lua_tolstring(L, -1, &len); @@ -370,7 +370,7 @@ static bool runFile(const char* name, lua_State* GL) std::string bytecode = Luau::compile(*source); int status = 0; - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { status = lua_resume(L, NULL, 0); } @@ -379,11 +379,7 @@ static bool runFile(const char* name, lua_State* GL) status = LUA_ERRSYNTAX; } - if (status == 0) - { - return true; - } - else + if (status != 0) { std::string error; @@ -400,8 +396,10 @@ static bool runFile(const char* name, lua_State* GL) error += lua_debugtrace(L); fprintf(stderr, "%s", error.c_str()); - return false; } + + lua_pop(GL, 1); + return status == 0; } static void report(const char* name, const Luau::Location& location, const char* type, const char* message) @@ -431,14 +429,18 @@ static bool compileFile(const char* name, CompileFormat format) try { Luau::BytecodeBuilder bcb; - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); - bcb.setDumpSource(*source); + + if (format == CompileFormat::Text) + { + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpSource(*source); + } Luau::compileOrThrow(bcb, *source); switch (format) { - case CompileFormat::Default: + case CompileFormat::Text: printf("%s", bcb.dumpEverything().c_str()); break; case CompileFormat::Binary: @@ -504,7 +506,7 @@ int main(int argc, char** argv) if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - CompileFormat format = CompileFormat::Default; + CompileFormat format = CompileFormat::Text; if (strcmp(argv[1], "--compile=binary") == 0) format = CompileFormat::Binary; @@ -514,27 +516,12 @@ int main(int argc, char** argv) _setmode(_fileno(stdout), _O_BINARY); #endif + std::vector files = getSourceFiles(argc, argv); + int failed = 0; - for (int i = 2; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; - - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 5 && name.rfind(".luau") == name.length() - 5) - failed += !compileFile(name.c_str(), format); - else if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !compileFile(name.c_str(), format); - }); - } - else - { - failed += !compileFile(argv[i], format); - } - } + for (const std::string& path : files) + failed += !compileFile(path.c_str(), format); return failed; } @@ -548,33 +535,25 @@ int main(int argc, char** argv) int profile = 0; for (int i = 1; i < argc; ++i) + { + if (argv[i][0] != '-') + continue; + if (strcmp(argv[i], "--profile") == 0) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + } if (profile) profilerStart(L, profile); - int failed = 0; + std::vector files = getSourceFiles(argc, argv); - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] == '-') - continue; + int failed = 0; - if (isDirectory(argv[i])) - { - traverseDirectory(argv[i], [&](const std::string& name) { - if (name.length() > 4 && name.rfind(".lua") == name.length() - 4) - failed += !runFile(name.c_str(), L); - }); - } - else - { - failed += !runFile(argv[i], L); - } - } + for (const std::string& path : files) + failed += !runFile(path.c_str(), L); if (profile) { diff --git a/CMakeLists.txt b/CMakeLists.txt index 36014a983..9c69521ed 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ project(Luau LANGUAGES CXX) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) +option(LUAU_WERROR "Warnings as errors" OFF) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) @@ -57,11 +58,18 @@ set(LUAU_OPTIONS) if(MSVC) list(APPEND LUAU_OPTIONS /D_CRT_SECURE_NO_WARNINGS) # We need to use the portable CRT functions. - list(APPEND LUAU_OPTIONS /WX) # Warnings are errors list(APPEND LUAU_OPTIONS /MP) # Distribute single project compilation across multiple cores else() list(APPEND LUAU_OPTIONS -Wall) # All warnings - list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors +endif() + +# Enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere +if(LUAU_WERROR) + if(MSVC) + list(APPEND LUAU_OPTIONS /WX) # Warnings are errors + else() + list(APPEND LUAU_OPTIONS -Werror) # Warnings are errors + endif() endif() target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) @@ -79,7 +87,10 @@ if(LUAU_BUILD_CLI) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) if(UNIX) - target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.Repl.CLI PRIVATE pthread) + endif() endif() if(NOT EMSCRIPTEN) diff --git a/Compiler/include/Luau/Compiler.h b/Compiler/include/Luau/Compiler.h index 4f88e602e..65e962dac 100644 --- a/Compiler/include/Luau/Compiler.h +++ b/Compiler/include/Luau/Compiler.h @@ -13,11 +13,9 @@ class AstNameTable; class BytecodeBuilder; class BytecodeEncoder; +// Note: this structure is duplicated in luacode.h, don't forget to change these in sync! struct CompileOptions { - // default bytecode version target; can be used to compile code for older clients - int bytecodeVersion = 1; - // 0 - no optimization // 1 - baseline optimization level that doesn't prevent debuggability // 2 - includes optimizations that harm debuggability such as inlining diff --git a/Compiler/include/luacode.h b/Compiler/include/luacode.h new file mode 100644 index 000000000..e235a2e77 --- /dev/null +++ b/Compiler/include/luacode.h @@ -0,0 +1,39 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include + +/* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUACODE_API +#define LUACODE_API extern +#endif + +typedef struct lua_CompileOptions lua_CompileOptions; + +struct lua_CompileOptions +{ + // 0 - no optimization + // 1 - baseline optimization level that doesn't prevent debuggability + // 2 - includes optimizations that harm debuggability such as inlining + int optimizationLevel; // default=1 + + // 0 - no debugging support + // 1 - line info & function names only; sufficient for backtraces + // 2 - full debug info with local & upvalue names; necessary for debugger + int debugLevel; // default=1 + + // 0 - no code coverage support + // 1 - statement coverage + // 2 - statement and expression coverage (verbose) + int coverageLevel; // default=0 + + // global builtin to construct vectors; disabled by default + const char* vectorLib; + const char* vectorCtor; + + // null-terminated array of globals that are mutable; disables the import optimization for fields accessed through these + const char** mutableGlobals; +}; + +/* compile source to bytecode; when source compilation fails, the resulting bytecode contains the encoded error. use free() to destroy */ +LUACODE_API char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize); diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9712f02f4..5b93c1dc0 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,9 +11,6 @@ #include LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresFenv, false) -LUAU_FASTFLAGVARIABLE(LuauPreloadClosuresUpval, false) -LUAU_FASTFLAGVARIABLE(LuauGenericSpecialGlobals, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -24,9 +21,6 @@ static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; -// TODO: Remove with LuauGenericSpecialGlobals -static const char* kSpecialGlobals[] = {"Game", "Workspace", "_G", "game", "plugin", "script", "shared", "workspace"}; - CompileError::CompileError(const Location& location, const std::string& message) : location(location) , message(message) @@ -466,7 +460,7 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it @@ -482,18 +476,6 @@ struct Compiler } } } - // Optimization: when closure has no upvalues, instead of allocating it every time we can share closure objects - // (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it is used) - else if (FFlag::LuauPreloadClosures && options.optimizationLevel >= 1 && f->upvals.empty() && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); - - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - return; - } - } if (!shared) bytecode.emitAD(LOP_NEWCLOSURE, target, pid); @@ -3298,8 +3280,7 @@ struct Compiler bool visit(AstStatLocalFunction* node) override { // record local->function association for some optimizations - if (FFlag::LuauPreloadClosuresUpval) - self->locals[node->name].func = node->func; + self->locals[node->name].func = node->func; return true; } @@ -3711,24 +3692,13 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (FFlag::LuauGenericSpecialGlobals) - { - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + if (AstName name = names.get("_G"); name.value) + compiler.globals[name].writable = true; - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; - } - else - { - for (const char* global : kSpecialGlobals) - { - if (AstName name = names.get(global); name.value) + if (options.mutableGlobals) + for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) compiler.globals[name].writable = true; - } - } // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written Compiler::AssignmentVisitor assignmentVisitor(&compiler); @@ -3742,7 +3712,7 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found - if (FFlag::LuauPreloadClosuresFenv && options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) + if (options.optimizationLevel >= 1 && (names.get("getfenv").value || names.get("setfenv").value)) { Compiler::FenvVisitor fenvVisitor(compiler.getfenvUsed, compiler.setfenvUsed); root->visit(&fenvVisitor); diff --git a/Compiler/src/lcode.cpp b/Compiler/src/lcode.cpp new file mode 100644 index 000000000..ee150b172 --- /dev/null +++ b/Compiler/src/lcode.cpp @@ -0,0 +1,29 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "luacode.h" + +#include "Luau/Compiler.h" + +#include + +char* luau_compile(const char* source, size_t size, lua_CompileOptions* options, size_t* outsize) +{ + LUAU_ASSERT(outsize); + + Luau::CompileOptions opts; + + if (options) + { + static_assert(sizeof(lua_CompileOptions) == sizeof(Luau::CompileOptions), "C and C++ interface must match"); + memcpy(static_cast(&opts), options, sizeof(opts)); + } + + std::string result = compile(std::string(source, size), opts); + + char* copy = static_cast(malloc(result.size())); + if (!copy) + return nullptr; + + memcpy(copy, result.data(), result.size()); + *outsize = result.size(); + return copy; +} diff --git a/Makefile b/Makefile index 5d51b3d4e..cab3d43f1 100644 --- a/Makefile +++ b/Makefile @@ -46,14 +46,20 @@ endif OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) # common flags -CXXFLAGS=-g -Wall -Werror +CXXFLAGS=-g -Wall LDFLAGS= -# temporary, for older gcc versions as they treat var in `if (type var = val)` as unused +# some gcc versions treat var in `if (type var = val)` as unused +# some gcc versions treat variables used in constexpr if blocks as unused ifeq ($(findstring g++,$(shell $(CXX) --version)),g++) CXXFLAGS+=-Wno-unused endif +# enabled in CI; we should be warning free on our main compiler versions but don't guarantee being warning free everywhere +ifneq ($(werror),) + CXXFLAGS+=-Werror +endif + # configuration-specific flags ifeq ($(config),release) CXXFLAGS+=-O2 -DNDEBUG diff --git a/Sources.cmake b/Sources.cmake index c30cf77d9..23b931c6b 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -25,9 +25,11 @@ target_sources(Luau.Compiler PRIVATE Compiler/include/Luau/Bytecode.h Compiler/include/Luau/BytecodeBuilder.h Compiler/include/Luau/Compiler.h + Compiler/include/luacode.h Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/lcode.cpp ) # Luau.Analysis Sources @@ -204,6 +206,7 @@ if(TARGET Luau.UnitTest) tests/TypeInfer.intersectionTypes.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp + tests/TypeInfer.singletons.test.cpp tests/TypeInfer.tables.test.cpp tests/TypeInfer.test.cpp tests/TypeInfer.tryUnify.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index a9d3e875a..1568d191f 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -102,6 +102,8 @@ LUA_API lua_State* lua_newstate(lua_Alloc f, void* ud); LUA_API void lua_close(lua_State* L); LUA_API lua_State* lua_newthread(lua_State* L); LUA_API lua_State* lua_mainthread(lua_State* L); +LUA_API void lua_resetthread(lua_State* L); +LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation @@ -162,8 +164,7 @@ LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); -LUA_API void lua_pushcfunction( - lua_State* L, lua_CFunction fn, const char* debugname = NULL, int nup = 0, lua_Continuation cont = NULL); +LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont); LUA_API void lua_pushboolean(lua_State* L, int b); LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API int lua_pushthread(lua_State* L); @@ -178,9 +179,9 @@ LUA_API void lua_rawget(lua_State* L, int idx); LUA_API void lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_createtable(lua_State* L, int narr, int nrec); -LUA_API void lua_setreadonly(lua_State* L, int idx, bool value); +LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); -LUA_API void lua_setsafeenv(lua_State* L, int idx, bool value); +LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); @@ -200,7 +201,7 @@ LUA_API int lua_setfenv(lua_State* L, int idx); /* ** `load' and `call' functions (load and run Luau bytecode) */ -LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env = 0); +LUA_API int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size, int env); LUA_API void lua_call(lua_State* L, int nargs, int nresults); LUA_API int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc); @@ -293,6 +294,8 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) #define lua_pushliteral(L, s) lua_pushlstring(L, "" s, (sizeof(s) / sizeof(char)) - 1) +#define lua_pushcfunction(L, fn, debugname) lua_pushcclosurek(L, fn, debugname, 0, NULL) +#define lua_pushcclosure(L, fn, debugname, nup) lua_pushcclosurek(L, fn, debugname, nup, NULL) #define lua_setglobal(L, s) lua_setfield(L, LUA_GLOBALSINDEX, (s)) #define lua_getglobal(L, s) lua_getfield(L, LUA_GLOBALSINDEX, (s)) @@ -319,8 +322,8 @@ LUA_API const char* lua_setlocal(lua_State* L, int level, int n); LUA_API const char* lua_getupvalue(lua_State* L, int funcindex, int n); LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); -LUA_API void lua_singlestep(lua_State* L, bool singlestep); -LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable); +LUA_API void lua_singlestep(lua_State* L, int enabled); +LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); @@ -361,6 +364,7 @@ struct lua_Callbacks void (*debuginterrupt)(lua_State* L, lua_Debug* ar); /* gets called when thread execution is interrupted by break in another thread */ void (*debugprotectederror)(lua_State* L); /* gets called when protected call results in an error */ }; +typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); diff --git a/VM/include/lualib.h b/VM/include/lualib.h index 30cffaff8..fa836955c 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -8,11 +8,12 @@ #define luaL_typeerror(L, narg, tname) luaL_typeerrorL(L, narg, tname) #define luaL_argerror(L, narg, extramsg) luaL_argerrorL(L, narg, extramsg) -typedef struct luaL_Reg +struct luaL_Reg { const char* name; lua_CFunction func; -} luaL_Reg; +}; +typedef struct luaL_Reg luaL_Reg; LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l); LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* e); @@ -75,6 +76,7 @@ struct luaL_Buffer struct TString* storage; char buffer[LUA_BUFFERSIZE]; }; +typedef struct luaL_Buffer luaL_Buffer; // when internal buffer storage is exhausted, a mutable string value 'storage' will be placed on the stack // in general, functions expect the mutable string buffer to be placed on top of the stack (top-1) diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 7e742644f..a79ba0d40 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -593,7 +593,7 @@ const char* lua_pushfstringL(lua_State* L, const char* fmt, ...) return ret; } -void lua_pushcfunction(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) +void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont) { luaC_checkGC(L); luaC_checkthreadsleep(L); @@ -698,13 +698,13 @@ void lua_createtable(lua_State* L, int narray, int nrec) return; } -void lua_setreadonly(lua_State* L, int objindex, bool value) +void lua_setreadonly(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); - t->readonly = value; + t->readonly = bool(enabled); return; } @@ -717,12 +717,12 @@ int lua_getreadonly(lua_State* L, int objindex) return res; } -void lua_setsafeenv(lua_State* L, int objindex, bool value) +void lua_setsafeenv(lua_State* L, int objindex, int enabled) { const TValue* o = index2adr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); - t->safeenv = value; + t->safeenv = bool(enabled); return; } diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 87fc16318..61798e2bc 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -436,8 +436,8 @@ static const luaL_Reg base_funcs[] = { static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFunction u) { - lua_pushcfunction(L, u); - lua_pushcfunction(L, f, name, 1); + lua_pushcfunction(L, u, NULL); + lua_pushcclosure(L, f, name, 1); lua_setfield(L, -2, name); } @@ -456,10 +456,10 @@ LUALIB_API int luaopen_base(lua_State* L) auxopen(L, "ipairs", luaB_ipairs, luaB_inext); auxopen(L, "pairs", luaB_pairs, luaB_next); - lua_pushcfunction(L, luaB_pcally, "pcall", 0, luaB_pcallcont); + lua_pushcclosurek(L, luaB_pcally, "pcall", 0, luaB_pcallcont); lua_setfield(L, -2, "pcall"); - lua_pushcfunction(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); + lua_pushcclosurek(L, luaB_xpcally, "xpcall", 0, luaB_xpcallcont); lua_setfield(L, -2, "xpcall"); return 1; diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index c72fe6748..907c43c42 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lcommon.h" #include "lnumutils.h" LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 9724c0e72..0178fae84 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,6 +5,8 @@ #include "lstate.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) + #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -208,8 +210,7 @@ static int cowrap(lua_State* L) { cocreate(L); - lua_pushcfunction(L, auxwrapy, NULL, 1, auxwrapcont); - + lua_pushcclosurek(L, auxwrapy, NULL, 1, auxwrapcont); return 1; } @@ -232,6 +233,34 @@ static int coyieldable(lua_State* L) return 1; } +static int coclose(lua_State* L) +{ + if (!FFlag::LuauCoroutineClose) + luaL_error(L, "coroutine.close is not enabled"); + + lua_State* co = lua_tothread(L, 1); + luaL_argexpected(L, co, 1, "thread"); + + int status = auxstatus(L, co); + if (status != CO_DEAD && status != CO_SUS) + luaL_error(L, "cannot close %s coroutine", statnames[status]); + + if (co->status == LUA_OK || co->status == LUA_YIELD) + { + lua_pushboolean(L, true); + lua_resetthread(co); + return 1; + } + else + { + lua_pushboolean(L, false); + if (lua_gettop(co)) + lua_xmove(co, L, 1); /* move error message */ + lua_resetthread(co); + return 2; + } +} + static const luaL_Reg co_funcs[] = { {"create", cocreate}, {"running", corunning}, @@ -239,6 +268,7 @@ static const luaL_Reg co_funcs[] = { {"wrap", cowrap}, {"yield", coyield}, {"isyieldable", coyieldable}, + {"close", coclose}, {NULL, NULL}, }; @@ -246,7 +276,7 @@ LUALIB_API int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); - lua_pushcfunction(L, coresumey, "resume", 0, coresumecont); + lua_pushcclosurek(L, coresumey, "resume", 0, coresumecont); lua_setfield(L, -2, "resume"); return 1; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 1890e6823..d77f84ef9 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -316,7 +316,7 @@ void luaG_breakpoint(lua_State* L, Proto* p, int line, bool enable) p->debuginsn[j] = LUAU_INSN_OP(p->code[j]); } - uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->code[i]); + uint8_t op = enable ? LOP_BREAK : LUAU_INSN_OP(p->debuginsn[i]); // patch just the opcode byte, leave arguments alone p->code[i] &= ~0xff; @@ -357,17 +357,17 @@ int luaG_getline(Proto* p, int pc) return p->abslineinfo[pc >> p->linegaplog2] + p->lineinfo[pc]; } -void lua_singlestep(lua_State* L, bool singlestep) +void lua_singlestep(lua_State* L, int enabled) { - L->singlestep = singlestep; + L->singlestep = bool(enabled); } -void lua_breakpoint(lua_State* L, int funcindex, int line, bool enable) +void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) { const TValue* func = luaA_toobject(L, funcindex); api_check(L, ttisfunction(func) && !clvalue(func)->isC); - luaG_breakpoint(L, clvalue(func)->l.p, line, enable); + luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 328b47e69..1259d4619 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,6 +19,7 @@ LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) +LUAU_FASTFLAG(LuauCoroutineClose) /* ** {====================================================== @@ -300,7 +301,10 @@ static void resume(lua_State* L, void* ud) if (L->status == 0) { // start coroutine - LUAU_ASSERT(L->ci == L->base_ci && firstArg > L->base); + LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); + if (FFlag::LuauCoroutineClose && firstArg == L->base) + luaG_runerror(L, "cannot resume dead coroutine"); + if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) return; diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index bf5e738f9..4e40165ab 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -22,7 +22,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) const luaL_Reg* lib = lualibs; for (; lib->func; lib++) { - lua_pushcfunction(L, lib->func); + lua_pushcfunction(L, lib->func, NULL); lua_pushstring(L, lib->name); lua_call(L, 1, 0); } diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 0b2dfb692..24e970635 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -124,6 +124,34 @@ void luaE_freethread(lua_State* L, lua_State* L1) luaM_free(L, L1, sizeof(lua_State), L1->memcat); } +void lua_resetthread(lua_State* L) +{ + /* close upvalues before clearing anything */ + luaF_close(L, L->stack); + /* clear call frames */ + CallInfo* ci = L->base_ci; + ci->func = L->stack; + ci->base = ci->func + 1; + ci->top = ci->base + LUA_MINSTACK; + setnilvalue(ci->func); + L->ci = ci; + luaD_reallocCI(L, BASIC_CI_SIZE); + /* clear thread state */ + L->status = LUA_OK; + L->base = L->ci->base; + L->top = L->ci->base; + L->nCcalls = L->baseCcalls = 0; + /* clear thread stack */ + luaD_reallocstack(L, BASIC_STACK_SIZE); + for (int i = 0; i < L->stacksize; i++) + setnilvalue(L->stack + i); +} + +int lua_isthreadreset(lua_State* L) +{ + return L->ci == L->base_ci && L->base == L->top && L->status == LUA_OK; +} + lua_State* lua_newstate(lua_Alloc f, void* ud) { int i; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 80a34483a..b576f8093 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -748,7 +748,7 @@ static int gmatch(lua_State* L) luaL_checkstring(L, 2); lua_settop(L, 2); lua_pushinteger(L, 0); - lua_pushcfunction(L, gmatch_aux, NULL, 3); + lua_pushcclosure(L, gmatch_aux, NULL, 3); return 1; } diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 6a0262962..378de3d0d 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -265,7 +265,7 @@ static int iter_aux(lua_State* L) static int iter_codes(lua_State* L) { luaL_checkstring(L, 1); - lua_pushcfunction(L, iter_aux); + lua_pushcfunction(L, iter_aux, NULL); lua_pushvalue(L, 1); lua_pushinteger(L, 0); return 3; diff --git a/bench/bench.py b/bench/bench.py index b23ca8913..39f219f31 100644 --- a/bench/bench.py +++ b/bench/bench.py @@ -25,8 +25,8 @@ import scipy from scipy import stats except ModuleNotFoundError: - print("scipy package is required") - exit(1) + print("Warning: scipy package is not installed, confidence values will not be available") + stats = None scriptdir = os.path.dirname(os.path.realpath(__file__)) defaultVm = 'luau.exe' if os.name == "nt" else './luau' @@ -200,11 +200,14 @@ def finalizeResult(result): result.sampleStdDev = math.sqrt(sumOfSquares / (result.count - 1)) result.unbiasedEst = result.sampleStdDev * result.sampleStdDev - # Two-tailed distribution with 95% conf. - tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) + if stats: + # Two-tailed distribution with 95% conf. + tValue = stats.t.ppf(1 - 0.05 / 2, result.count - 1) - # Compute confidence interval - result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + # Compute confidence interval + result.sampleConfidenceInterval = tValue * result.sampleStdDev / math.sqrt(result.count) + else: + result.sampleConfidenceInterval = result.sampleStdDev else: result.sampleStdDev = 0 result.unbiasedEst = 0 @@ -377,14 +380,19 @@ def analyzeResult(subdir, main, comparisons): tStat = abs(main.avg - compare.avg) / (pooledStdDev * math.sqrt(2 / main.count)) degreesOfFreedom = 2 * main.count - 2 - # Two-tailed distribution with 95% conf. - tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) - - noSignificantDifference = tStat < tCritical + if stats: + # Two-tailed distribution with 95% conf. + tCritical = stats.t.ppf(1 - 0.05 / 2, degreesOfFreedom) - pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + noSignificantDifference = tStat < tCritical + pValue = 2 * (1 - stats.t.cdf(tStat, df = degreesOfFreedom)) + else: + noSignificantDifference = None + pValue = -1 - if noSignificantDifference: + if noSignificantDifference is None: + verdict = "" + elif noSignificantDifference: verdict = "likely same" elif main.avg < compare.avg: verdict = "likely worse" diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index c85fac7d3..ae2399e49 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -257,7 +257,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) lua_State* L = lua_newthread(globalState); luaL_sandboxthread(L); - if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size()) == 0) + if (luau_load(L, "=fuzz", bytecode.data(), bytecode.size(), 0) == 0) { interruptDeadline = std::chrono::system_clock::now() + kInterruptTimeout; diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8a7798f3a..5a7c86023 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -91,10 +91,6 @@ struct ACFixture : ACFixtureImpl { }; -struct UnfrozenACFixture : ACFixtureImpl -{ -}; - TEST_SUITE_BEGIN("AutocompleteTest"); TEST_CASE_FIXTURE(ACFixture, "empty_program") @@ -1919,9 +1915,10 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end @@ -1950,9 +1947,10 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// CLI-45692: Remove UnfrozenACFixture here -TEST_CASE_FIXTURE(UnfrozenACFixture, "type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + check(R"( local function a(x: boolean) end local function b(x: number?) end @@ -2484,7 +2482,7 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") { loadDefinition(R"( declare y: { diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 7f03019c3..4ce8d08ae 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -11,8 +11,6 @@ #include LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauPreloadClosuresFenv) -LUAU_FASTFLAG(LuauPreloadClosuresUpval) LUAU_FASTFLAG(LuauGenericSpecialGlobals) using namespace Luau; @@ -2797,7 +2795,7 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosuresUpval) + if (FFlag::LuauPreloadClosures) { // recursive capture CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( @@ -3479,15 +3477,13 @@ CAPTURE VAL R0 RETURN R1 1 )"); - if (FFlag::LuauPreloadClosuresFenv) - { - // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion - CHECK_EQ("\n" + compileFunction(R"( + // if they don't need upvalues but we sense that environment may be modified, we disable this to avoid fenv-related identity confusion + CHECK_EQ("\n" + compileFunction(R"( setfenv(1, {}) return function() print("hi") end )", - 1), - R"( + 1), + R"( GETIMPORT R0 1 LOADN R1 1 NEWTABLE R2 0 0 @@ -3496,23 +3492,21 @@ NEWCLOSURE R0 P0 RETURN R0 1 )"); - // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature - CHECK_EQ("\n" + compileFunction(R"( + // note that fenv analysis isn't flow-sensitive right now, which is sort of a feature + CHECK_EQ("\n" + compileFunction(R"( if false then setfenv(1, {}) end return function() print("hi") end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 RETURN R0 1 )"); - } } TEST_CASE("SharedClosure") { ScopedFastFlag sff1("LuauPreloadClosures", true); - ScopedFastFlag sff2("LuauPreloadClosuresUpval", true); // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( @@ -3671,7 +3665,7 @@ RETURN R0 0 )"); } -TEST_CASE("LuauGenericSpecialGlobals") +TEST_CASE("MutableGlobals") { const char* source = R"( print() @@ -3685,43 +3679,6 @@ shared.print() workspace.print() )"; - { - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", false}; - - // Check Roblox globals are here - CHECK_EQ("\n" + compileFunction0(source), R"( -GETIMPORT R0 1 -CALL R0 0 0 -GETIMPORT R1 3 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 5 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 7 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 9 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 11 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 13 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 15 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -GETIMPORT R1 17 -GETTABLEKS R0 R1 K0 -CALL R0 0 0 -RETURN R0 0 -)"); - } - - ScopedFastFlag genericSpecialGlobals{"LuauGenericSpecialGlobals", true}; - // Check Roblox globals are no longer here CHECK_EQ("\n" + compileFunction0(source), R"( GETIMPORT R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index c1b790b9c..e495a2136 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -1,5 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Compiler.h" +#include "lua.h" +#include "lualib.h" +#include "luacode.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/ModuleResolver.h" @@ -10,9 +12,6 @@ #include "doctest.h" #include "ScopedFlags.h" -#include "lua.h" -#include "lualib.h" - #include #include @@ -49,8 +48,12 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); - if (luau_load(L, chunkname, bytecode.data(), bytecode.size()) == 0) + size_t bytecodeSize = 0; + char* bytecode = luau_compile(s, l, nullptr, &bytecodeSize); + int result = luau_load(L, chunkname, bytecode, bytecodeSize, 0); + free(bytecode); + + if (result == 0) return 1; lua_pushnil(L); @@ -179,21 +182,17 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - Luau::CompileOptions copts; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; // default copts.debugLevel = 2; // for debugger tests copts.vectorCtor = "vector"; // for vector tests - std::string bytecode = Luau::compile(source, copts); - int status = 0; + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); + free(bytecode); - if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size()) == 0) - { - status = lua_resume(L, nullptr, 0); - } - else - { - status = LUA_ERRSYNTAX; - } + int status = (result == 0) ? lua_resume(L, nullptr, 0) : LUA_ERRSYNTAX; while (yield && (status == LUA_YIELD || status == LUA_BREAK)) { @@ -332,53 +331,61 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { + ScopedFastFlag sff("LuauCoroutineClose", true); + runConformance("coroutine.lua"); } -TEST_CASE("PCall") +static int cxxthrow(lua_State* L) { - runConformance("pcall.lua", [](lua_State* L) { - lua_pushcfunction(L, [](lua_State* L) -> int { #if LUA_USE_LONGJMP - luaL_error(L, "oops"); + luaL_error(L, "oops"); #else - throw std::runtime_error("oops"); + throw std::runtime_error("oops"); #endif - }); +} + +TEST_CASE("PCall") +{ + runConformance("pcall.lua", [](lua_State* L) { + lua_pushcfunction(L, cxxthrow, "cxxthrow"); lua_setglobal(L, "cxxthrow"); - lua_pushcfunction(L, [](lua_State* L) -> int { - lua_State* co = lua_tothread(L, 1); - lua_xmove(L, co, 1); - lua_resumeerror(co, L); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + lua_State* co = lua_tothread(L, 1); + lua_xmove(L, co, 1); + lua_resumeerror(co, L); + return 0; + }, + "resumeerror"); lua_setglobal(L, "resumeerror"); }); } TEST_CASE("Pack") { - ScopedFastFlag sff{ "LuauStrPackUBCastFix", true }; - + ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; + runConformance("tpack.lua"); } TEST_CASE("Vector") { runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector); + lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); lua_pushvector(L, 0.0f, 0.0f, 0.0f); luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index); + lua_pushcfunction(L, lua_vector_index, nullptr); lua_settable(L, -3); lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall); + lua_pushcfunction(L, lua_vector_namecall, nullptr); lua_settable(L, -3); lua_setreadonly(L, -1, true); @@ -513,15 +520,19 @@ TEST_CASE("Debugger") }; // add breakpoint() function - lua_pushcfunction(L, [](lua_State* L) -> int { - int line = luaL_checkinteger(L, 1); - - lua_Debug ar = {}; - lua_getinfo(L, 1, "f", &ar); - - lua_breakpoint(L, -1, line, true); - return 0; - }); + lua_pushcfunction( + L, + [](lua_State* L) -> int { + int line = luaL_checkinteger(L, 1); + bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; + + lua_Debug ar = {}; + lua_getinfo(L, 1, "f", &ar); + + lua_breakpoint(L, -1, line, enabled); + return 0; + }, + "breakpoint"); lua_setglobal(L, "breakpoint"); }, [](lua_State* L) { @@ -744,7 +755,7 @@ TEST_CASE("ExceptionObject") if (nsize == 0) { free(ptr); - return NULL; + return nullptr; } else if (nsize > 512 * 1024) { diff --git a/tests/IostreamOptional.h b/tests/IostreamOptional.h index e55b5b0c3..e0756badd 100644 --- a/tests/IostreamOptional.h +++ b/tests/IostreamOptional.h @@ -4,7 +4,8 @@ #include #include -namespace std { +namespace std +{ inline std::ostream& operator<<(std::ostream& lhs, const std::nullopt_t&) { diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 18f55d2c1..7a3543c7c 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -203,6 +203,8 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + TypeVar freeTy(FreeTypeVar{TypeLevel{}}); TypePackVar freeTp(FreeTypePack{TypeLevel{}}); @@ -212,12 +214,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") bool encounteredFreeType = false; TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTy)); + CHECK_EQ("any", toString(clonedTy)); CHECK(encounteredFreeType); encounteredFreeType = false; TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); - CHECK(Luau::get(clonedTp)); + CHECK_EQ("...any", toString(clonedTp)); CHECK(encounteredFreeType); } diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index b076e9ad7..80a258f5b 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -198,7 +198,8 @@ TEST_CASE_FIXTURE(Fixture, "stringifying_table_type_correctly_use_matching_table TypeVar tv{ttv}; - ToStringOptions o{/* exhaustive= */ false, /* useLineBreaks= */ false, /* functionTypeArguments= */ false, /* hideTableKind= */ false, 40}; + ToStringOptions o; + o.maxTableLength = 40; CHECK_EQ(toString(&tv, o), "{| a: number, b: number, c: number, d: number, e: number, ... 5 more ... |}"); } @@ -395,7 +396,7 @@ local function target(callback: nil) return callback(4, "hello") end )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(requireType("target")), "(nil) -> (*unknown*)"); + CHECK_EQ("(nil) -> (*unknown*)", toString(requireType("target"))); } TEST_CASE_FIXTURE(Fixture, "toStringGenericPack") @@ -469,4 +470,110 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") CHECK_EQ(toString(tableTy), "Table
"); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") +{ + CheckResult result = check(R"( + local function id(x) return x end + )"); + + TypeId ty = requireType("id"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("id(x: a): a", toStringNamedFunction("id", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") +{ + CheckResult result = check(R"( + local function map(arr, fn) + local t = {} + for i = 0, #arr do + t[i] = fn(arr[i]) + end + return t + end + )"); + + TypeId ty = requireType("map"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); +} + +TEST_CASE("toStringNamedFunction_unit_f") +{ + TypePackVar empty{TypePack{}}; + FunctionTypeVar ftv{&empty, &empty, {}, false}; + CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") +{ + CheckResult result = check(R"( + local function f(x: a, ...): (a, a, b...) + return x, x, ... + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(x: a, ...: any): (a, a, b...)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") +{ + CheckResult result = check(R"( + local function f(): ...number + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): ...number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") +{ + CheckResult result = check(R"( + local function f(): (string, ...number) + return 'a', 1, 2, 3 + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(): (string, ...number)", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") +{ + CheckResult result = check(R"( + local f: (number, y: number) -> number + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + CHECK_EQ("f(_: number, y: number): number", toStringNamedFunction("f", *ftv)); +} + +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") +{ + CheckResult result = check(R"( + local function f(x: T, g: (T) -> U)): () + end + )"); + + TypeId ty = requireType("f"); + auto ftv = get(follow(ty)); + + ToStringOptions opts; + opts.hideNamedFunctionTypeParameters = true; + CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index c27f8083b..74ce155c2 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -109,7 +109,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_stop_typechecking_after_reporting_duplicate_typ CheckResult result = check(R"( type A = number type A = string -- Redefinition of type 'A', previously defined at line 1 - local foo: string = 1 -- No "Type 'number' could not be converted into 'string'" + local foo: string = 1 -- "Type 'number' could not be converted into 'string'" )"); LUAU_REQUIRE_ERROR_COUNT(2, result); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 2e4001641..091c2f012 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -381,6 +381,8 @@ TEST_CASE_FIXTURE(Fixture, "typeof_expr") TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( type A = B type B = A @@ -390,7 +392,7 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_error_on_tight_loop") )"); TypeId fType = requireType("aa"); - const ErrorTypeVar* ftv = get(follow(fType)); + const AnyTypeVar* ftv = get(follow(fType)); REQUIRE(ftv != nullptr); REQUIRE(!result.errors.empty()); } diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index de2f01544..88c2dc85d 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -289,7 +289,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_self_generic_methods") end )"); // TODO: Should typecheck but currently errors CLI-39916 - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "infer_generic_property") @@ -352,7 +352,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_generic_types") -- so this assignment should fail local b: boolean = f(true) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") @@ -368,7 +368,7 @@ TEST_CASE_FIXTURE(Fixture, "dont_leak_inferred_generic_types") local y: number = id(37) end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "dont_substitute_bound_types") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 733fc39b3..fe8e7ff90 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -704,9 +704,10 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); else CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - if (FFlag::LuauQuantifyInPlace2) + if (FFlag::LuauQuantifyInPlace2) CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" else CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp new file mode 100644 index 000000000..5f95efd52 --- /dev/null +++ b/tests/TypeInfer.singletons.test.cpp @@ -0,0 +1,377 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Fixture.h" + +#include "doctest.h" +#include "Luau/BuiltinDefinitions.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeSingletons"); + +TEST_CASE_FIXTURE(Fixture, "bool_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: false = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: "bar" = "bar" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = false + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'false' could not be converted into 'true'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "bar" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "\n" = "\000\r" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '"\000\r"' could not be converted into '"\n"')", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: true = true + local b: boolean = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: "foo" = "foo" + local b: string = a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "foo") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a: true, b: "foo") end + f(true, "bar") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bar\"' could not be converted into '\"foo\"'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, "foo") + g(false, 37) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + function f(a, b) end + local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) + g(true, 37) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("Other overloads are also not viable: (false, number) -> ()", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "foo" + local b : MyEnum = "bar" + local c : MyEnum = "baz" + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( + type MyEnum = "foo" | "bar" | "baz" + local a : MyEnum = "bang" + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '\"bang\"' could not be converted into '\"bar\" | \"baz\" | \"foo\"'; none of the union options are compatible", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type MyEnum1 = "foo" | "bar" + type MyEnum2 = MyEnum1 | "baz" + local a : MyEnum1 = "foo" + local b : MyEnum2 = a + local c : string = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauExpectedTypesOfProperties", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Dog = { tag = "Dog", howls = true } + local b : Animal = { tag = "Cat", meows = true } + local c : Animal = a + c = b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", howls = true } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Dog = { tag: "Dog", howls: boolean } + type Cat = { tag: "Cat", meows: boolean } + type Animal = Dog | Cat + local a : Animal = { tag = "Cat", meows = true } + a.tag = "Dog" + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["foo"] : number, + ["$$bar"] : string, + baz : boolean + } + local t: T = { + ["foo"] = 37, + ["$$bar"] = "hi", + baz = true + } + local a: number = t.foo + local b: string = t["$$bar"] + local c: boolean = t.baz + t.foo = 5 + t["$$bar"] = "lo" + t.baz = false + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} +TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type T = { + ["$$bar"] : string, + } + local t: T = { + ["$$bar"] = "hi", + } + t["$$bar"] = 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + type S = "bar" + type T = { + [("foo")] : number, + [S] : string, + } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + }; + + CheckResult result = check(R"( + --!strict + local x: { ["<>"] : number } + x = { ["\n"] = 5 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + toString(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 30d9130a5..99fd8339c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -362,7 +362,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") CHECK_EQ(2, result.errors.size()); TypeId p = requireType("p"); - CHECK_EQ(*p, *typeChecker.errorType); + CHECK_EQ("*unknown*", toString(p)); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") @@ -480,7 +480,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") @@ -496,7 +496,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") @@ -512,7 +512,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(typeChecker.anyType, requireType("a")); + CHECK_EQ("any", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") @@ -526,7 +526,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") @@ -542,7 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(typeChecker.errorType, requireType("a")); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") @@ -673,7 +673,7 @@ TEST_CASE_FIXTURE(Fixture, "string_index") REQUIRE(nat); CHECK_EQ("string", toString(nat->ty)); - CHECK(get(requireType("t"))); + CHECK_EQ("*unknown*", toString(requireType("t"))); } TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") @@ -1456,7 +1456,7 @@ TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") auto hootyType = requireType(bModule, "Hooty"); - CHECK_MESSAGE(get(follow(hootyType)) != nullptr, "Should be an error: " << toString(hootyType)); + CHECK_EQ("*unknown*", toString(hootyType)); } TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") @@ -2032,7 +2032,7 @@ TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); } -TEST_CASE_FIXTURE(Fixture, "error_types_propagate") +TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( local err = (true).x @@ -2049,10 +2049,10 @@ TEST_CASE_FIXTURE(Fixture, "error_types_propagate") CHECK_EQ("boolean", toString(err->table)); CHECK_EQ("x", err->key); - CHECK(nullptr != get(requireType("c"))); - CHECK(nullptr != get(requireType("d"))); - CHECK(nullptr != get(requireType("e"))); - CHECK(nullptr != get(requireType("f"))); + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") @@ -2068,7 +2068,7 @@ TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") CHECK_EQ("unknown", err->name); - CHECK(nullptr != get(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") @@ -2077,9 +2077,7 @@ TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") local a = Utility.Create "Foo" {} )"); - TypeId aType = requireType("a"); - - REQUIRE_MESSAGE(nullptr != get(aType), "Not an error: " << toString(aType)); + CHECK_EQ("*unknown*", toString(requireType("a"))); } TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") @@ -2146,6 +2144,8 @@ TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2175,11 +2175,13 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( --!strict local Vec3 = {} @@ -2209,7 +2211,7 @@ TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersectio CHECK_EQ("Vec3", toString(requireType("b"))); CHECK_EQ("Vec3", toString(requireType("c"))); CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK(get(requireType("e"))); + CHECK_EQ("Vec3", toString(requireType("e"))); } TEST_CASE_FIXTURE(Fixture, "compare_numbers") @@ -2901,6 +2903,8 @@ end TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") { + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + CheckResult result = check(R"( local x: number = 9999 function x:y(z: number) @@ -2908,7 +2912,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(2, result); } TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") @@ -2920,7 +2924,7 @@ function x:y(z: number) end )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") @@ -3799,7 +3803,7 @@ TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") print(a) )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERRORS(result); CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); } @@ -4215,7 +4219,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4238,7 +4242,7 @@ TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); REQUIRE(t0); - CHECK(get(t0->type)); + CHECK_EQ("*unknown*", toString(t0->type)); auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { return get(err); @@ -4394,6 +4398,25 @@ TEST_CASE_FIXTURE(Fixture, "record_matching_overload") CHECK_EQ(toString(*it), "(number) -> number"); } +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") { // Simple direct arg to arg propagation @@ -4740,4 +4763,20 @@ TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") } } +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 2d697fc97..9f9a007f1 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -121,9 +121,26 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "members_of_failed_typepack_unification_are_u LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId bType = requireType("b"); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(TryUnifyFixture, "result_of_failed_typepack_unification_is_constrained") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + function f(arg: number) return arg end + local a + local b + local c = f(a, b) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_MESSAGE(get(bType), "Should be an error: " << toString(bType)); + CHECK_EQ("a", toString(requireType("a"))); + CHECK_EQ("*unknown*", toString(requireType("b"))); + CHECK_EQ("number", toString(requireType("c"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails") @@ -167,15 +184,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") CHECK(state.errors.empty()); } -TEST_CASE_FIXTURE(TryUnifyFixture, "unifying_variadic_pack_with_error_should_work") -{ - TypePackId variadicPack = arena.addTypePack(TypePackVar{VariadicTypePack{typeChecker.numberType}}); - TypePackId errorPack = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{Unifiable::Error{}})}); - - state.tryUnify(variadicPack, errorPack); - REQUIRE_EQ(0, state.errors.size()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "variadics_should_use_reversed_properly") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 9f29b6428..48496b895 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -200,8 +200,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_missing_property") CHECK_EQ(mup->missing[0], *bTy); CHECK_EQ(mup->key, "x"); - TypeId r = requireType("r"); - CHECK_MESSAGE(get(r), "Expected error, got " << toString(r)); + CHECK_EQ("*unknown*", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_a_union_type_with_one_property_of_type_any") @@ -283,7 +282,7 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "optional_union_follow") +TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") { CheckResult result = check(R"( local y: number? = 2 diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 753296422..4d9b12953 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -319,4 +319,58 @@ for i=0,30 do assert(#T2 == 1 or T2[#T2] == 42) end +-- test coroutine.close +do + -- ok to close a dead coroutine + local co = coroutine.create(type) + assert(coroutine.resume(co, "testing 'coroutine.close'")) + assert(coroutine.status(co) == "dead") + local st, msg = coroutine.close(co) + assert(st and msg == nil) + -- also ok to close it again + st, msg = coroutine.close(co) + assert(st and msg == nil) + + + -- cannot close the running coroutine + coroutine.wrap(function() + local st, msg = pcall(coroutine.close, coroutine.running()) + assert(not st and string.find(msg, "running")) + end)() + + -- cannot close a "normal" coroutine + coroutine.wrap(function() + local co = coroutine.running() + coroutine.wrap(function () + local st, msg = pcall(coroutine.close, co) + assert(not st and string.find(msg, "normal")) + end)() + end)() + + -- closing a coroutine after an error + local co = coroutine.create(error) + local obj = {42} + local st, msg = coroutine.resume(co, obj) + assert(not st and msg == obj) + st, msg = coroutine.close(co) + assert(not st and msg == obj) + -- after closing, no more errors + st, msg = coroutine.close(co) + assert(st and msg == nil) + + -- closing a coroutine that has outstanding upvalues + local f + local co = coroutine.create(function() + local a = 42 + f = function() return a end + coroutine.yield() + a = 20 + end) + coroutine.resume(co) + assert(f() == 42) + st, msg = coroutine.close(co) + assert(st and msg == nil) + assert(f() == 42) +end + return'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 5e69fc6bc..6ba99fb9b 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -45,4 +45,13 @@ breakpoint(38) -- break inside corobad() local co = coroutine.create(corobad) assert(coroutine.resume(co) == false) -- this breaks, resumes and dies! +function bar() + print("in bar") +end + +breakpoint(49) +breakpoint(49, false) -- validate that disabling breakpoints works + +bar() + return 'OK' From eed18acec8677b380fdbb7f424d79e1c7dae4273 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 2 Dec 2021 15:20:08 -0800 Subject: [PATCH 08/32] Sync to upstream/release/506 --- Analysis/include/Luau/FileResolver.h | 23 - Analysis/include/Luau/Module.h | 12 +- Analysis/include/Luau/ToDot.h | 31 ++ Analysis/include/Luau/TxnLog.h | 9 - Analysis/include/Luau/TypeInfer.h | 1 - Analysis/include/Luau/TypeVar.h | 15 - Analysis/include/Luau/Unifier.h | 25 +- Analysis/include/Luau/UnifierSharedState.h | 8 + Analysis/include/Luau/VisitTypeVar.h | 4 +- Analysis/src/Autocomplete.cpp | 64 ++- Analysis/src/BuiltinDefinitions.cpp | 6 +- Analysis/src/Error.cpp | 118 ++--- Analysis/src/Frontend.cpp | 27 +- Analysis/src/IostreamHelpers.cpp | 21 +- Analysis/src/JsonEncoder.cpp | 9 +- Analysis/src/Module.cpp | 132 ++--- Analysis/src/Quantify.cpp | 13 +- Analysis/src/RequireTracer.cpp | 195 +------ Analysis/src/Substitution.cpp | 29 +- Analysis/src/ToDot.cpp | 378 ++++++++++++++ Analysis/src/ToString.cpp | 162 +++--- Analysis/src/Transpiler.cpp | 15 +- Analysis/src/TxnLog.cpp | 37 +- Analysis/src/TypeAttach.cpp | 52 +- Analysis/src/TypeInfer.cpp | 325 +++++------- Analysis/src/TypeVar.cpp | 364 -------------- Analysis/src/Unifier.cpp | 306 ++--------- Ast/src/Parser.cpp | 78 +-- CLI/Analyze.cpp | 23 +- CLI/Repl.cpp | 39 -- CLI/Web.cpp | 106 ++++ CMakeLists.txt | 57 ++- Compiler/src/Compiler.cpp | 14 +- Makefile | 6 +- Sources.cmake | 10 + VM/include/lua.h | 11 +- VM/include/luaconf.h | 38 ++ VM/include/lualib.h | 6 + VM/src/lapi.cpp | 65 ++- VM/src/laux.cpp | 95 ++-- VM/src/lbaselib.cpp | 4 +- VM/src/lbitlib.cpp | 2 +- VM/src/lbuiltins.cpp | 12 +- VM/src/lcorolib.cpp | 2 +- VM/src/ldblib.cpp | 2 +- VM/src/ldo.cpp | 71 +-- VM/src/lgc.cpp | 548 +------------------- VM/src/lgcdebug.cpp | 558 +++++++++++++++++++++ VM/src/linit.cpp | 10 +- VM/src/lmathlib.cpp | 5 +- VM/src/lmem.cpp | 8 +- VM/src/lnumutils.h | 8 + VM/src/lobject.cpp | 2 +- VM/src/lobject.h | 23 +- VM/src/loslib.cpp | 2 +- VM/src/lstrlib.cpp | 2 +- VM/src/ltable.cpp | 19 +- VM/src/ltablib.cpp | 2 +- VM/src/lutf8lib.cpp | 2 +- VM/src/lvmexecute.cpp | 30 +- VM/src/lvmload.cpp | 6 +- VM/src/lvmutils.cpp | 18 +- bench/gc/test_LB_mandel.lua | 2 +- bench/tests/shootout/mandel.lua | 2 +- bench/tests/shootout/qt.lua | 10 +- fuzz/proto.cpp | 4 +- tests/AstQuery.test.cpp | 23 + tests/Autocomplete.test.cpp | 40 +- tests/Compiler.test.cpp | 168 +++++++ tests/Conformance.test.cpp | 114 +++-- tests/Fixture.cpp | 38 +- tests/Fixture.h | 5 +- tests/Frontend.test.cpp | 17 - tests/Linter.test.cpp | 16 + tests/Module.test.cpp | 66 ++- tests/Parser.test.cpp | 2 - tests/ToDot.test.cpp | 366 ++++++++++++++ tests/Transpiler.test.cpp | 3 - tests/TypeInfer.aliases.test.cpp | 2 - tests/TypeInfer.generics.test.cpp | 21 +- tests/TypeInfer.provisional.test.cpp | 2 - tests/TypeInfer.tables.test.cpp | 70 +++ tests/TypeInfer.test.cpp | 20 + tests/TypeInfer.typePacks.cpp | 33 -- tests/TypeVar.test.cpp | 45 ++ tests/conformance/apicalls.lua | 8 +- tests/conformance/basic.lua | 5 +- tests/conformance/closure.lua | 8 +- tests/conformance/constructs.lua | 2 +- tests/conformance/coroutine.lua | 2 +- tests/conformance/datetime.lua | 2 +- tests/conformance/debug.lua | 2 +- tests/conformance/errors.lua | 32 +- tests/conformance/gc.lua | 8 +- tests/conformance/nextvar.lua | 6 +- tests/conformance/pcall.lua | 2 +- tests/conformance/utf8.lua | 2 +- tests/conformance/vector.lua | 31 +- tools/svg.py | 9 +- 99 files changed, 2895 insertions(+), 2558 deletions(-) create mode 100644 Analysis/include/Luau/ToDot.h create mode 100644 Analysis/src/ToDot.cpp create mode 100644 CLI/Web.cpp create mode 100644 VM/src/lgcdebug.cpp create mode 100644 tests/ToDot.test.cpp diff --git a/Analysis/include/Luau/FileResolver.h b/Analysis/include/Luau/FileResolver.h index 9b74fc12d..0fdcce161 100644 --- a/Analysis/include/Luau/FileResolver.h +++ b/Analysis/include/Luau/FileResolver.h @@ -51,13 +51,6 @@ struct FileResolver { return std::nullopt; } - - // DEPRECATED APIS - // These are going to be removed with LuauNewRequireTrace2 - virtual bool moduleExists(const ModuleName& name) const = 0; - virtual std::optional fromAstFragment(AstExpr* expr) const = 0; - virtual ModuleName concat(const ModuleName& lhs, std::string_view rhs) const = 0; - virtual std::optional getParentModuleName(const ModuleName& name) const = 0; }; struct NullFileResolver : FileResolver @@ -66,22 +59,6 @@ struct NullFileResolver : FileResolver { return std::nullopt; } - bool moduleExists(const ModuleName& name) const override - { - return false; - } - std::optional fromAstFragment(AstExpr* expr) const override - { - return std::nullopt; - } - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs; - } - std::optional getParentModuleName(const ModuleName& name) const override - { - return std::nullopt; - } }; } // namespace Luau diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index d08448351..2e41674bf 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -78,9 +78,15 @@ void unfreeze(TypeArena& arena); using SeenTypes = std::unordered_map; using SeenTypePacks = std::unordered_map; -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType = nullptr); +struct CloneState +{ + int recursionCount = 0; + bool encounteredFreeType = false; +}; + +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState); struct Module { diff --git a/Analysis/include/Luau/ToDot.h b/Analysis/include/Luau/ToDot.h new file mode 100644 index 000000000..ce518d3ae --- /dev/null +++ b/Analysis/include/Luau/ToDot.h @@ -0,0 +1,31 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" + +#include + +namespace Luau +{ +struct TypeVar; +using TypeId = const TypeVar*; + +struct TypePackVar; +using TypePackId = const TypePackVar*; + +struct ToDotOptions +{ + bool showPointers = true; // Show pointer value in the node label + bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes +}; + +std::string toDot(TypeId ty, const ToDotOptions& opts); +std::string toDot(TypePackId tp, const ToDotOptions& opts); + +std::string toDot(TypeId ty); +std::string toDot(TypePackId tp); + +void dumpDot(TypeId ty); +void dumpDot(TypePackId tp); + +} // namespace Luau diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 322abd198..29988a3b9 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -25,15 +25,6 @@ struct TxnLog { } - explicit TxnLog(const std::vector>& ownedSeen) - : originalSeenSize(ownedSeen.size()) - , ownedSeen(ownedSeen) - , sharedSeen(nullptr) - { - // This is deprecated! - LUAU_ASSERT(!FFlag::LuauShareTxnSeen); - } - TxnLog(const TxnLog&) = delete; TxnLog& operator=(const TxnLog&) = delete; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 78d642c58..9f553bc14 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -297,7 +297,6 @@ struct TypeChecker private: Unifier mkUnifier(const Location& location); - Unifier mkUnifier(const std::vector>& seen, const Location& location); // These functions are only safe to call when we are in the process of typechecking a module. diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 093ea4319..8c4c2f34f 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -517,21 +517,6 @@ extern SingletonTypes singletonTypes; void persist(TypeId ty); void persist(TypePackId tp); -struct ToDotOptions -{ - bool showPointers = true; // Show pointer value in the node label - bool duplicatePrimitives = true; // Display primitive types and 'any' as separate nodes -}; - -std::string toDot(TypeId ty, const ToDotOptions& opts); -std::string toDot(TypePackId tp, const ToDotOptions& opts); - -std::string toDot(TypeId ty); -std::string toDot(TypePackId tp); - -void dumpDot(TypeId ty); -void dumpDot(TypePackId tp); - const TypeLevel* getLevel(TypeId ty); TypeLevel* getMutableLevel(TypeId ty); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 503034a1b..4588cdd8c 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -19,12 +19,6 @@ enum Variance Invariant }; -struct UnifierCounters -{ - int recursionCount = 0; - int iterationCount = 0; -}; - struct Unifier { TypeArena* const types; @@ -37,20 +31,11 @@ struct Unifier Variance variance = Covariant; CountMismatch::Context ctx = CountMismatch::Arg; - UnifierCounters* counters; - UnifierCounters countersData; - - std::shared_ptr counters_DEPRECATED; - UnifierSharedState& sharedState; Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, - UnifierCounters* counters = nullptr); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED = nullptr, - UnifierCounters* counters = nullptr); + Variance variance, UnifierSharedState& sharedState); // Test whether the two type vars unify. Never commits the result. ErrorVec canUnify(TypeId superTy, TypeId subTy); @@ -92,9 +77,9 @@ struct Unifier public: // Report an "infinite type error" if the type "needle" already occurs within "haystack" void occursCheck(TypeId needle, TypeId haystack); - void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack); + void occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack); void occursCheck(TypePackId needle, TypePackId haystack); - void occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack); + void occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack); Unifier makeChildUnifier(); @@ -106,10 +91,6 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); - - // Remove with FFlagLuauCacheUnifyTableResults - DenseHashSet tempSeenTy_DEPRECATED{nullptr}; - DenseHashSet tempSeenTp_DEPRECATED{nullptr}; }; } // namespace Luau diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index f252a004b..88997c41a 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -24,6 +24,12 @@ struct TypeIdPairHash } }; +struct UnifierCounters +{ + int recursionCount = 0; + int iterationCount = 0; +}; + struct UnifierSharedState { UnifierSharedState(InternalErrorReporter* iceHandler) @@ -39,6 +45,8 @@ struct UnifierSharedState DenseHashSet tempSeenTy{nullptr}; DenseHashSet tempSeenTp{nullptr}; + + UnifierCounters counters; }; } // namespace Luau diff --git a/Analysis/include/Luau/VisitTypeVar.h b/Analysis/include/Luau/VisitTypeVar.h index a866655c9..740854b33 100644 --- a/Analysis/include/Luau/VisitTypeVar.h +++ b/Analysis/include/Luau/VisitTypeVar.h @@ -5,8 +5,6 @@ #include "Luau/TypeVar.h" #include "Luau/TypePack.h" -LUAU_FASTFLAG(LuauCacheUnifyTableResults) - namespace Luau { @@ -101,7 +99,7 @@ void visit(TypeId ty, F& f, Set& seen) // Some visitors want to see bound tables, that's why we visit the original type if (apply(ty, *ttv, seen, f)) { - if (FFlag::LuauCacheUnifyTableResults && ttv->boundTo) + if (ttv->boundTo) { visit(*ttv->boundTo, f, seen); } diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 6fc0b3f88..db2d1d0e5 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,9 +12,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(ElseElseIfCompletionImprovements, false); LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -203,8 +203,9 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { SeenTypes seenTypes; SeenTypePacks seenTypePacks; - expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, nullptr); - actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, nullptr); + CloneState cloneState; + expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, cloneState); + actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, cloneState); auto errors = unifier.canUnify(expectedType, actualType); return errors.empty(); @@ -229,28 +230,51 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*it); - if (canUnify(expectedType, ty)) - return TypeCorrectKind::Correct; + if (FFlag::LuauAutocompletePreferToCallFunctions) + { + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) + { + auto [retHead, retTail] = flatten(ftv->retType); - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); + if (!retHead.empty() && canUnify(expectedType, retHead.front())) + return TypeCorrectKind::CorrectFunctionResult; - if (!ftv) - return TypeCorrectKind::None; + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(expectedType, vtp->ty)) + return TypeCorrectKind::CorrectFunctionResult; + } + } + + return canUnify(expectedType, ty) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + } + else + { + if (canUnify(expectedType, ty)) + return TypeCorrectKind::Correct; - auto [retHead, retTail] = flatten(ftv->retType); + // We also want to suggest functions that return compatible result + const FunctionTypeVar* ftv = get(ty); - if (!retHead.empty()) - return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (!ftv) + return TypeCorrectKind::None; - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; - } + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty()) + return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; - return TypeCorrectKind::None; + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail))) + return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + } + + return TypeCorrectKind::None; + } } enum class PropIndexType @@ -1413,7 +1437,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (AstStatIf* statIf = node->as(); FFlag::ElseElseIfCompletionImprovements && statIf && !statIf->hasElse) + else if (AstStatIf* statIf = node->as(); statIf && !statIf->hasElse) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index 62a06a3cd..bac94a2bd 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,8 +8,6 @@ #include -LUAU_FASTFLAG(LuauNewRequireTrace2) - /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -473,9 +471,7 @@ static std::optional> magicFunctionRequire( if (!checkRequirePath(typechecker, expr.args.data[0])) return std::nullopt; - const AstExpr* require = FFlag::LuauNewRequireTrace2 ? &expr : expr.args.data[0]; - - if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, *require)) + if (auto moduleInfo = typechecker.resolver->resolveModuleInfo(typechecker.currentModuleName, expr)) return ExprResult{arena.addTypePack({typechecker.checkRequire(scope, *moduleInfo, expr.location)})}; return std::nullopt; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index f80d50a7a..8334bd626 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,57 +7,14 @@ #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - -static std::string wrongNumberOfArgsString_DEPRECATED(size_t expectedCount, size_t actualCount, bool isTypeArgs = false) -{ - std::string s = "expects " + std::to_string(expectedCount) + " "; - - if (isTypeArgs) - s += "type "; - - s += "argument"; - if (expectedCount != 1) - s += "s"; - - s += ", but "; - - if (actualCount == 0) - { - s += "none"; - } - else - { - if (actualCount < expectedCount) - s += "only "; - - s += std::to_string(actualCount); - } - - s += (actualCount == 1) ? " is" : " are"; - - s += " specified"; - - return s; -} - static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { - std::string s; + std::string s = "expects "; - if (FFlag::LuauTypeAliasPacks) - { - s = "expects "; + if (isVariadic) + s += "at least "; - if (isVariadic) - s += "at least "; - - s += std::to_string(expectedCount) + " "; - } - else - { - s = "expects " + std::to_string(expectedCount) + " "; - } + s += std::to_string(expectedCount) + " "; if (argPrefix) s += std::string(argPrefix) + " "; @@ -188,10 +145,7 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - if (FFlag::LuauTypeAliasPacks) - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); - else - return "Argument count mismatch. Function " + wrongNumberOfArgsString_DEPRECATED(e.expected, e.actual); + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); } LUAU_ASSERT(!"Unknown context"); @@ -232,7 +186,7 @@ struct ErrorConverter std::string operator()(const Luau::IncorrectGenericParameterCount& e) const { std::string name = e.name; - if (!e.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !e.typeFun.typePackParams.empty())) + if (!e.typeFun.typeParams.empty() || !e.typeFun.typePackParams.empty()) { name += "<"; bool first = true; @@ -246,36 +200,25 @@ struct ErrorConverter name += toString(t); } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId t : e.typeFun.typePackParams) { - for (TypePackId t : e.typeFun.typePackParams) - { - if (first) - first = false; - else - name += ", "; - - name += toString(t); - } + if (first) + first = false; + else + name += ", "; + + name += toString(t); } name += ">"; } - if (FFlag::LuauTypeAliasPacks) - { - if (e.typeFun.typeParams.size() != e.actualParameters) - return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); - + if (e.typeFun.typeParams.size() != e.actualParameters) return "Generic type '" + name + "' " + - wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); - } - else - { - return "Generic type '" + name + "' " + - wrongNumberOfArgsString_DEPRECATED(e.typeFun.typeParams.size(), e.actualParameters, /*isTypeArgs*/ true); - } + wrongNumberOfArgsString(e.typeFun.typeParams.size(), e.actualParameters, "type", !e.typeFun.typePackParams.empty()); + + return "Generic type '" + name + "' " + + wrongNumberOfArgsString(e.typeFun.typePackParams.size(), e.actualPackParameters, "type pack", /*isVariadic*/ false); } std::string operator()(const Luau::SyntaxError& e) const @@ -591,11 +534,8 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC if (typeFun.typeParams.size() != rhs.typeFun.typeParams.size()) return false; - if (FFlag::LuauTypeAliasPacks) - { - if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) - return false; - } + if (typeFun.typePackParams.size() != rhs.typeFun.typePackParams.size()) + return false; for (size_t i = 0; i < typeFun.typeParams.size(); ++i) { @@ -603,13 +543,10 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC return false; } - if (FFlag::LuauTypeAliasPacks) + for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) { - for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) - { - if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) - return false; - } + if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + return false; } return true; @@ -733,14 +670,14 @@ bool containsParseErrorName(const TypeError& error) } template -void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) +void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState cloneState) { auto clone = [&](auto&& ty) { - return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks); + return ::Luau::clone(ty, destArena, seenTypes, seenTypePacks, cloneState); }; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks); + copyError(e, destArena, seenTypes, seenTypePacks, cloneState); }; if constexpr (false) @@ -864,9 +801,10 @@ void copyErrors(ErrorVec& errors, TypeArena& destArena) { SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; auto visitErrorData = [&](auto&& e) { - copyError(e, destArena, seenTypes, seenTypePacks); + copyError(e, destArena, seenTypes, seenTypePacks, cloneState); }; LUAU_ASSERT(!destArena.typeVars.isFrozen()); diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 1e97705dc..e332f07d4 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,10 +18,7 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauResolveModuleNameWithoutACurrentModule, false) -LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) -LUAU_FASTFLAG(LuauNewRequireTrace2) namespace Luau { @@ -96,10 +93,11 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; for (const auto& [name, ty] : checkedModule->declaredGlobals) { - TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeId globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); std::string documentationSymbol = packageName + "/global/" + name; generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; @@ -110,7 +108,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) { - TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks); + TypeFun globalTy = clone(ty, typeChecker.globalTypes, seenTypes, seenTypePacks, cloneState); std::string documentationSymbol = packageName + "/globaltype/" + name; generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; @@ -427,15 +425,16 @@ CheckResult Frontend::check(const ModuleName& name) SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; for (const auto& [expr, strictTy] : strictModule->astTypes) - module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) - module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) - module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks); + module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); } stats.timeCheck += getTimestamp() - timestamp; @@ -885,16 +884,13 @@ std::optional FrontendModuleResolver::resolveModuleInfo(const Module // If we can't find the current module name, that's because we bypassed the frontend's initializer // and called typeChecker.check directly. (This is done by autocompleteSource, for example). // In that case, requires will always fail. - if (FFlag::LuauResolveModuleNameWithoutACurrentModule) - return std::nullopt; - else - throw std::runtime_error("Frontend::resolveModuleName: Unknown currentModuleName '" + currentModuleName + "'"); + return std::nullopt; } const auto& exprs = it->second.exprs; const ModuleInfo* info = exprs.find(&pathExpr); - if (!info || (!FFlag::LuauNewRequireTrace2 && info->name.empty())) + if (!info) return std::nullopt; return *info; @@ -911,10 +907,7 @@ const ModulePtr FrontendModuleResolver::getModule(const ModuleName& moduleName) bool FrontendModuleResolver::moduleExists(const ModuleName& moduleName) const { - if (FFlag::LuauNewRequireTrace2) - return frontend->sourceNodes.count(moduleName) != 0; - else - return frontend->fileResolver->moduleExists(moduleName); + return frontend->sourceNodes.count(moduleName) != 0; } std::string FrontendModuleResolver::getHumanReadableModuleName(const ModuleName& moduleName) const diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 3b2671213..ac46b5a49 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -2,8 +2,6 @@ #include "Luau/IostreamHelpers.h" #include "Luau/ToString.h" -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace Luau { @@ -94,7 +92,7 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "IncorrectGenericParameterCount { name = " << error.name; - if (!error.typeFun.typeParams.empty() || (FFlag::LuauTypeAliasPacks && !error.typeFun.typePackParams.empty())) + if (!error.typeFun.typeParams.empty() || !error.typeFun.typePackParams.empty()) { stream << "<"; bool first = true; @@ -108,17 +106,14 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo stream << toString(t); } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId t : error.typeFun.typePackParams) { - for (TypePackId t : error.typeFun.typePackParams) - { - if (first) - first = false; - else - stream << ", "; - - stream << toString(t); - } + if (first) + first = false; + else + stream << ", "; + + stream << toString(t); } stream << ">"; diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 064accba5..c7f623eea 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,8 +5,6 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace Luau { @@ -615,12 +613,7 @@ struct AstJsonEncoder : public AstVisitor writeNode(node, "AstStatTypeAlias", [&]() { PROP(name); PROP(generics); - - if (FFlag::LuauTypeAliasPacks) - { - PROP(genericPacks); - } - + PROP(genericPacks); PROP(type); PROP(exported); }); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 32a0646ae..b4b6eb425 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -1,20 +1,20 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Module.h" +#include "Luau/Common.h" +#include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" -#include "Luau/Common.h" #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) -LUAU_FASTFLAG(LuauTypeAliasPacks) -LUAU_FASTFLAGVARIABLE(LuauCloneBoundTables, false) +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) namespace Luau { @@ -120,12 +120,6 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) return allocated; } -using SeenTypes = std::unordered_map; -using SeenTypePacks = std::unordered_map; - -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); -TypeId clone(TypeId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType); - namespace { @@ -138,11 +132,12 @@ struct TypePackCloner; struct TypeCloner { - TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + TypeCloner(TypeArena& dest, TypeId typeId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) : dest(dest) , typeId(typeId) , seenTypes(seenTypes) , seenTypePacks(seenTypePacks) + , cloneState(cloneState) { } @@ -150,8 +145,7 @@ struct TypeCloner TypeId typeId; SeenTypes& seenTypes; SeenTypePacks& seenTypePacks; - - bool* encounteredFreeType = nullptr; + CloneState& cloneState; template void defaultClone(const T& t); @@ -178,13 +172,14 @@ struct TypePackCloner TypePackId typePackId; SeenTypes& seenTypes; SeenTypePacks& seenTypePacks; - bool* encounteredFreeType = nullptr; + CloneState& cloneState; - TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks) + TypePackCloner(TypeArena& dest, TypePackId typePackId, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) : dest(dest) , typePackId(typePackId) , seenTypes(seenTypes) , seenTypePacks(seenTypePacks) + , cloneState(cloneState) { } @@ -197,8 +192,7 @@ struct TypePackCloner void operator()(const Unifiable::Free& t) { - if (encounteredFreeType) - *encounteredFreeType = true; + cloneState.encounteredFreeType = true; TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); TypePackId cloned = dest.addTypePack(*err); @@ -218,13 +212,13 @@ struct TypePackCloner // We just need to be sure that we rewrite pointers both to the binder and the bindee to the same pointer. void operator()(const Unifiable::Bound& t) { - TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypePackId cloned = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypePacks[typePackId] = cloned; } void operator()(const VariadicTypePack& t) { - TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, encounteredFreeType)}}); + TypePackId cloned = dest.addTypePack(TypePackVar{VariadicTypePack{clone(t.ty, dest, seenTypes, seenTypePacks, cloneState)}}); seenTypePacks[typePackId] = cloned; } @@ -236,10 +230,10 @@ struct TypePackCloner seenTypePacks[typePackId] = cloned; for (TypeId ty : t.head) - destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + destTp->head.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); if (t.tail) - destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, encounteredFreeType); + destTp->tail = clone(*t.tail, dest, seenTypes, seenTypePacks, cloneState); } }; @@ -252,8 +246,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { - if (encounteredFreeType) - *encounteredFreeType = true; + cloneState.encounteredFreeType = true; TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); TypeId cloned = dest.addType(*err); seenTypes[typeId] = cloned; @@ -266,7 +259,7 @@ void TypeCloner::operator()(const Unifiable::Generic& t) void TypeCloner::operator()(const Unifiable::Bound& t) { - TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypeId boundTo = clone(t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypes[typeId] = boundTo; } @@ -294,23 +287,23 @@ void TypeCloner::operator()(const FunctionTypeVar& t) seenTypes[typeId] = result; for (TypeId generic : t.generics) - ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, encounteredFreeType)); + ftv->generics.push_back(clone(generic, dest, seenTypes, seenTypePacks, cloneState)); for (TypePackId genericPack : t.genericPacks) - ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, encounteredFreeType)); + ftv->genericPacks.push_back(clone(genericPack, dest, seenTypes, seenTypePacks, cloneState)); ftv->tags = t.tags; - ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, encounteredFreeType); + ftv->argTypes = clone(t.argTypes, dest, seenTypes, seenTypePacks, cloneState); ftv->argNames = t.argNames; - ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, encounteredFreeType); + ftv->retType = clone(t.retType, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const TableTypeVar& t) { // If table is now bound to another one, we ignore the content of the original - if (FFlag::LuauCloneBoundTables && t.boundTo) + if (t.boundTo) { - TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); + TypeId boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, cloneState); seenTypes[typeId] = boundTo; return; } @@ -326,34 +319,21 @@ void TypeCloner::operator()(const TableTypeVar& t) ttv->level = TypeLevel{0, 0}; for (const auto& [name, prop] : t.props) - ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + ttv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.indexer) - ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, encounteredFreeType), - clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, encounteredFreeType)}; - - if (!FFlag::LuauCloneBoundTables) - { - if (t.boundTo) - ttv->boundTo = clone(*t.boundTo, dest, seenTypes, seenTypePacks, encounteredFreeType); - } + ttv->indexer = TableIndexer{clone(t.indexer->indexType, dest, seenTypes, seenTypePacks, cloneState), + clone(t.indexer->indexResultType, dest, seenTypes, seenTypePacks, cloneState)}; for (TypeId& arg : ttv->instantiatedTypeParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId& arg : ttv->instantiatedTypePackParams) - arg = clone(arg, dest, seenTypes, seenTypePacks, encounteredFreeType); - } + for (TypePackId& arg : ttv->instantiatedTypePackParams) + arg = clone(arg, dest, seenTypes, seenTypePacks, cloneState); if (ttv->state == TableState::Free) { - if (FFlag::LuauCloneBoundTables || !t.boundTo) - { - if (encounteredFreeType) - *encounteredFreeType = true; - } + cloneState.encounteredFreeType = true; ttv->state = TableState::Sealed; } @@ -369,8 +349,8 @@ void TypeCloner::operator()(const MetatableTypeVar& t) MetatableTypeVar* mtv = getMutable(result); seenTypes[typeId] = result; - mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, encounteredFreeType); - mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); + mtv->table = clone(t.table, dest, seenTypes, seenTypePacks, cloneState); + mtv->metatable = clone(t.metatable, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const ClassTypeVar& t) @@ -381,13 +361,13 @@ void TypeCloner::operator()(const ClassTypeVar& t) seenTypes[typeId] = result; for (const auto& [name, prop] : t.props) - ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, encounteredFreeType), prop.deprecated, {}, prop.location, prop.tags}; + ctv->props[name] = {clone(prop.type, dest, seenTypes, seenTypePacks, cloneState), prop.deprecated, {}, prop.location, prop.tags}; if (t.parent) - ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, encounteredFreeType); + ctv->parent = clone(*t.parent, dest, seenTypes, seenTypePacks, cloneState); if (t.metatable) - ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, encounteredFreeType); + ctv->metatable = clone(*t.metatable, dest, seenTypes, seenTypePacks, cloneState); } void TypeCloner::operator()(const AnyTypeVar& t) @@ -404,7 +384,7 @@ void TypeCloner::operator()(const UnionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); } void TypeCloner::operator()(const IntersectionTypeVar& t) @@ -416,7 +396,7 @@ void TypeCloner::operator()(const IntersectionTypeVar& t) LUAU_ASSERT(option != nullptr); for (TypeId ty : t.parts) - option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + option->parts.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); } void TypeCloner::operator()(const LazyTypeVar& t) @@ -426,17 +406,18 @@ void TypeCloner::operator()(const LazyTypeVar& t) } // anonymous namespace -TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { if (tp->persistent) return tp; + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + TypePackId& res = seenTypePacks[tp]; if (res == nullptr) { - TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; + TypePackCloner cloner{dest, tp, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } @@ -446,17 +427,18 @@ TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypeP return res; } -TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { if (typeId->persistent) return typeId; + RecursionLimiter _ra(&cloneState.recursionCount, FInt::LuauTypeCloneRecursionLimit); + TypeId& res = seenTypes[typeId]; if (res == nullptr) { - TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks}; - cloner.encounteredFreeType = encounteredFreeType; + TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. asMutable(res)->documentationSymbol = typeId->documentationSymbol; } @@ -467,19 +449,16 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks return res; } -TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, bool* encounteredFreeType) +TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { TypeFun result; for (TypeId ty : typeFun.typeParams) - result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, encounteredFreeType)); + result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId tp : typeFun.typePackParams) - result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, encounteredFreeType)); - } + for (TypePackId tp : typeFun.typePackParams) + result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); - result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, encounteredFreeType); + result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); return result; } @@ -519,19 +498,18 @@ bool Module::clonePublicInterface() LUAU_ASSERT(interfaceTypes.typeVars.empty()); LUAU_ASSERT(interfaceTypes.typePacks.empty()); - bool encounteredFreeType = false; - - SeenTypePacks seenTypePacks; SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + CloneState cloneState; ScopePtr moduleScope = getModuleScope(); - moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + moduleScope->returnType = clone(moduleScope->returnType, interfaceTypes, seenTypes, seenTypePacks, cloneState); if (moduleScope->varargPack) - moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, &encounteredFreeType); + pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (TypeId ty : moduleScope->returnType) if (get(follow(ty))) @@ -540,7 +518,7 @@ bool Module::clonePublicInterface() freeze(internalTypes); freeze(interfaceTypes); - return encounteredFreeType; + return cloneState.encounteredFreeType; } } // namespace Luau diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index bf6d81aa6..c773e208b 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,6 +4,8 @@ #include "Luau/VisitTypeVar.h" +LUAU_FASTFLAGVARIABLE(LuauQuantifyVisitOnce, false) + namespace Luau { @@ -79,7 +81,16 @@ struct Quantifier void quantify(ModulePtr module, TypeId ty, TypeLevel level) { Quantifier q{std::move(module), level}; - visitTypeVar(ty, q); + + if (FFlag::LuauQuantifyVisitOnce) + { + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, q, seen); + } + else + { + visitTypeVar(ty, q); + } FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/RequireTracer.cpp b/Analysis/src/RequireTracer.cpp index b72f53f99..8ed245fbb 100644 --- a/Analysis/src/RequireTracer.cpp +++ b/Analysis/src/RequireTracer.cpp @@ -4,182 +4,9 @@ #include "Luau/Ast.h" #include "Luau/Module.h" -LUAU_FASTFLAGVARIABLE(LuauTraceRequireLookupChild, false) -LUAU_FASTFLAGVARIABLE(LuauNewRequireTrace2, false) - namespace Luau { -namespace -{ - -struct RequireTracerOld : AstVisitor -{ - explicit RequireTracerOld(FileResolver* fileResolver, const ModuleName& currentModuleName) - : fileResolver(fileResolver) - , currentModuleName(currentModuleName) - { - LUAU_ASSERT(!FFlag::LuauNewRequireTrace2); - } - - FileResolver* const fileResolver; - ModuleName currentModuleName; - DenseHashMap locals{nullptr}; - RequireTraceResult result; - - std::optional fromAstFragment(AstExpr* expr) - { - if (auto g = expr->as(); g && g->name == "script") - return currentModuleName; - - return fileResolver->fromAstFragment(expr); - } - - bool visit(AstStatLocal* stat) override - { - for (size_t i = 0; i < stat->vars.size; ++i) - { - AstLocal* local = stat->vars.data[i]; - - if (local->annotation) - { - if (AstTypeTypeof* ann = local->annotation->as()) - ann->expr->visit(this); - } - - if (i < stat->values.size) - { - AstExpr* expr = stat->values.data[i]; - expr->visit(this); - - const ModuleInfo* info = result.exprs.find(expr); - if (info) - locals[local] = info->name; - } - } - - return false; - } - - bool visit(AstExprGlobal* global) override - { - std::optional name = fromAstFragment(global); - if (name) - result.exprs[global] = {*name}; - - return false; - } - - bool visit(AstExprLocal* local) override - { - const ModuleName* name = locals.find(local->local); - if (name) - result.exprs[local] = {*name}; - - return false; - } - - bool visit(AstExprIndexName* indexName) override - { - indexName->expr->visit(this); - - const ModuleInfo* info = result.exprs.find(indexName->expr); - if (info) - { - if (indexName->index == "parent" || indexName->index == "Parent") - { - if (auto parent = fileResolver->getParentModuleName(info->name)) - result.exprs[indexName] = {*parent}; - } - else - result.exprs[indexName] = {fileResolver->concat(info->name, indexName->index.value)}; - } - - return false; - } - - bool visit(AstExprIndexExpr* indexExpr) override - { - indexExpr->expr->visit(this); - - const ModuleInfo* info = result.exprs.find(indexExpr->expr); - const AstExprConstantString* str = indexExpr->index->as(); - if (info && str) - { - result.exprs[indexExpr] = {fileResolver->concat(info->name, std::string_view(str->value.data, str->value.size))}; - } - - indexExpr->index->visit(this); - - return false; - } - - bool visit(AstExprTypeAssertion* expr) override - { - return false; - } - - // If we see game:GetService("StringLiteral") or Game:GetService("StringLiteral"), then rewrite to game.StringLiteral. - // Else traverse arguments and trace requires to them. - bool visit(AstExprCall* call) override - { - for (AstExpr* arg : call->args) - arg->visit(this); - - call->func->visit(this); - - AstExprGlobal* globalName = call->func->as(); - if (globalName && globalName->name == "require" && call->args.size >= 1) - { - if (const ModuleInfo* moduleInfo = result.exprs.find(call->args.data[0])) - result.requires.push_back({moduleInfo->name, call->location}); - - return false; - } - - AstExprIndexName* indexName = call->func->as(); - if (!indexName) - return false; - - std::optional rootName = fromAstFragment(indexName->expr); - - if (FFlag::LuauTraceRequireLookupChild && !rootName) - { - if (const ModuleInfo* moduleInfo = result.exprs.find(indexName->expr)) - rootName = moduleInfo->name; - } - - if (!rootName) - return false; - - bool supportedLookup = indexName->index == "GetService" || - (FFlag::LuauTraceRequireLookupChild && (indexName->index == "FindFirstChild" || indexName->index == "WaitForChild")); - - if (!supportedLookup) - return false; - - if (call->args.size != 1) - return false; - - AstExprConstantString* name = call->args.data[0]->as(); - if (!name) - return false; - - std::string_view v{name->value.data, name->value.size}; - if (v.end() != std::find(v.begin(), v.end(), '/')) - return false; - - result.exprs[call] = {fileResolver->concat(*rootName, v)}; - - // 'WaitForChild' can be used on modules that are not available at the typecheck time, but will be available at runtime - // If we fail to find such module, we will not report an UnknownRequire error - if (FFlag::LuauTraceRequireLookupChild && indexName->index == "WaitForChild") - result.exprs[call].optional = true; - - return false; - } -}; - struct RequireTracer : AstVisitor { RequireTracer(RequireTraceResult& result, FileResolver* fileResolver, const ModuleName& currentModuleName) @@ -188,7 +15,6 @@ struct RequireTracer : AstVisitor , currentModuleName(currentModuleName) , locals(nullptr) { - LUAU_ASSERT(FFlag::LuauNewRequireTrace2); } bool visit(AstExprTypeAssertion* expr) override @@ -328,24 +154,13 @@ struct RequireTracer : AstVisitor std::vector requires; }; -} // anonymous namespace - RequireTraceResult traceRequires(FileResolver* fileResolver, AstStatBlock* root, const ModuleName& currentModuleName) { - if (FFlag::LuauNewRequireTrace2) - { - RequireTraceResult result; - RequireTracer tracer{result, fileResolver, currentModuleName}; - root->visit(&tracer); - tracer.process(); - return result; - } - else - { - RequireTracerOld tracer{fileResolver, currentModuleName}; - root->visit(&tracer); - return tracer.result; - } + RequireTraceResult result; + RequireTracer tracer{result, fileResolver, currentModuleName}; + root->visit(&tracer); + tracer.process(); + return result; } } // namespace Luau diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index ca2b30f52..3d004bee3 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -7,8 +7,6 @@ #include LUAU_FASTINTVARIABLE(LuauTarjanChildLimit, 1000) -LUAU_FASTFLAGVARIABLE(LuauSubstitutionDontReplaceIgnoredTypes, false) -LUAU_FASTFLAG(LuauTypeAliasPacks) namespace Luau { @@ -39,11 +37,8 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypeId itp : ttv->instantiatedTypeParams) visitChild(itp); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp); - } + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp); } else if (const MetatableTypeVar* mtv = get(ty)) { @@ -339,10 +334,10 @@ std::optional Substitution::substitute(TypeId ty) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + if (!ignoreChildren(oldTy)) replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + if (!ignoreChildren(oldTp)) replaceChildren(newTp); TypeId newTy = replace(ty); return newTy; @@ -359,10 +354,10 @@ std::optional Substitution::substitute(TypePackId tp) return std::nullopt; for (auto [oldTy, newTy] : newTypes) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTy)) + if (!ignoreChildren(oldTy)) replaceChildren(newTy); for (auto [oldTp, newTp] : newPacks) - if (!FFlag::LuauSubstitutionDontReplaceIgnoredTypes || !ignoreChildren(oldTp)) + if (!ignoreChildren(oldTp)) replaceChildren(newTp); TypePackId newTp = replace(tp); return newTp; @@ -393,10 +388,7 @@ TypeId Substitution::clone(TypeId ty) clone.name = ttv->name; clone.syntheticName = ttv->syntheticName; clone.instantiatedTypeParams = ttv->instantiatedTypeParams; - - if (FFlag::LuauTypeAliasPacks) - clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; - + clone.instantiatedTypePackParams = ttv->instantiatedTypePackParams; clone.tags = ttv->tags; result = addType(std::move(clone)); } @@ -505,11 +497,8 @@ void Substitution::replaceChildren(TypeId ty) for (TypeId& itp : ttv->instantiatedTypeParams) itp = replace(itp); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId& itp : ttv->instantiatedTypePackParams) - itp = replace(itp); - } + for (TypePackId& itp : ttv->instantiatedTypePackParams) + itp = replace(itp); } else if (MetatableTypeVar* mtv = getMutable(ty)) { diff --git a/Analysis/src/ToDot.cpp b/Analysis/src/ToDot.cpp new file mode 100644 index 000000000..df9d41881 --- /dev/null +++ b/Analysis/src/ToDot.cpp @@ -0,0 +1,378 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/ToDot.h" + +#include "Luau/ToString.h" +#include "Luau/TypePack.h" +#include "Luau/TypeVar.h" +#include "Luau/StringUtils.h" + +#include +#include + +namespace Luau +{ + +namespace +{ + +struct StateDot +{ + StateDot(ToDotOptions opts) + : opts(opts) + { + } + + ToDotOptions opts; + + std::unordered_set seenTy; + std::unordered_set seenTp; + std::unordered_map tyToIndex; + std::unordered_map tpToIndex; + int nextIndex = 1; + std::string result; + + bool canDuplicatePrimitive(TypeId ty); + + void visitChildren(TypeId ty, int index); + void visitChildren(TypePackId ty, int index); + + void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); + void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); + + void startNode(int index); + void finishNode(); + + void startNodeLabel(); + void finishNodeLabel(TypeId ty); + void finishNodeLabel(TypePackId tp); +}; + +bool StateDot::canDuplicatePrimitive(TypeId ty) +{ + if (get(ty)) + return false; + + return get(ty) || get(ty); +} + +void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) +{ + if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) + tyToIndex[ty] = nextIndex++; + + int index = tyToIndex[ty]; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, index); + } + + if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) + { + if (get(ty)) + formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); + else if (get(ty)) + formatAppend(result, "n%d [label=\"any\"];\n", index); + } + else + { + visitChildren(ty, index); + } +} + +void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) +{ + if (!tpToIndex.count(tp)) + tpToIndex[tp] = nextIndex++; + + if (parentIndex != 0) + { + if (linkName) + formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); + else + formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); + } + + visitChildren(tp, tpToIndex[tp]); +} + +void StateDot::startNode(int index) +{ + formatAppend(result, "n%d [", index); +} + +void StateDot::finishNode() +{ + formatAppend(result, "];\n"); +} + +void StateDot::startNodeLabel() +{ + formatAppend(result, "label=\""); +} + +void StateDot::finishNodeLabel(TypeId ty) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", ty); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::finishNodeLabel(TypePackId tp) +{ + if (opts.showPointers) + formatAppend(result, "\n0x%p", tp); + // additional common attributes can be added here as well + result += "\""; +} + +void StateDot::visitChildren(TypeId ty, int index) +{ + if (seenTy.count(ty)) + return; + seenTy.insert(ty); + + startNode(index); + startNodeLabel(); + + if (const BoundTypeVar* btv = get(ty)) + { + formatAppend(result, "BoundTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(btv->boundTo, index); + } + else if (const FunctionTypeVar* ftv = get(ty)) + { + formatAppend(result, "FunctionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(ftv->argTypes, index, "arg"); + visitChild(ftv->retType, index, "ret"); + } + else if (const TableTypeVar* ttv = get(ty)) + { + if (ttv->name) + formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); + else if (ttv->syntheticName) + formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); + else + formatAppend(result, "TableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + if (ttv->boundTo) + return visitChild(*ttv->boundTo, index, "boundTo"); + + for (const auto& [name, prop] : ttv->props) + visitChild(prop.type, index, name.c_str()); + if (ttv->indexer) + { + visitChild(ttv->indexer->indexType, index, "[index]"); + visitChild(ttv->indexer->indexResultType, index, "[value]"); + } + for (TypeId itp : ttv->instantiatedTypeParams) + visitChild(itp, index, "typeParam"); + + for (TypePackId itp : ttv->instantiatedTypePackParams) + visitChild(itp, index, "typePackParam"); + } + else if (const MetatableTypeVar* mtv = get(ty)) + { + formatAppend(result, "MetatableTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + visitChild(mtv->table, index, "table"); + visitChild(mtv->metatable, index, "metatable"); + } + else if (const UnionTypeVar* utv = get(ty)) + { + formatAppend(result, "UnionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId opt : utv->options) + visitChild(opt, index); + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + formatAppend(result, "IntersectionTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + + for (TypeId part : itv->parts) + visitChild(part, index); + } + else if (const GenericTypeVar* gtv = get(ty)) + { + if (gtv->explicitName) + formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); + else + formatAppend(result, "GenericTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const FreeTypeVar* ftv = get(ty)) + { + formatAppend(result, "FreeTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "AnyTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); + finishNodeLabel(ty); + finishNode(); + } + else if (get(ty)) + { + formatAppend(result, "ErrorTypeVar %d", index); + finishNodeLabel(ty); + finishNode(); + } + else if (const ClassTypeVar* ctv = get(ty)) + { + formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); + finishNodeLabel(ty); + finishNode(); + + for (const auto& [name, prop] : ctv->props) + visitChild(prop.type, index, name.c_str()); + + if (ctv->parent) + visitChild(*ctv->parent, index, "[parent]"); + + if (ctv->metatable) + visitChild(*ctv->metatable, index, "[metatable]"); + } + else + { + LUAU_ASSERT(!"unknown type kind"); + finishNodeLabel(ty); + finishNode(); + } +} + +void StateDot::visitChildren(TypePackId tp, int index) +{ + if (seenTp.count(tp)) + return; + seenTp.insert(tp); + + startNode(index); + startNodeLabel(); + + if (const BoundTypePack* btp = get(tp)) + { + formatAppend(result, "BoundTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(btp->boundTo, index); + } + else if (const TypePack* tpp = get(tp)) + { + formatAppend(result, "TypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + for (TypeId tv : tpp->head) + visitChild(tv, index); + if (tpp->tail) + visitChild(*tpp->tail, index, "tail"); + } + else if (const VariadicTypePack* vtp = get(tp)) + { + formatAppend(result, "VariadicTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + + visitChild(vtp->ty, index); + } + else if (const FreeTypePack* ftp = get(tp)) + { + formatAppend(result, "FreeTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (const GenericTypePack* gtp = get(tp)) + { + if (gtp->explicitName) + formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); + else + formatAppend(result, "GenericTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else if (get(tp)) + { + formatAppend(result, "ErrorTypePack %d", index); + finishNodeLabel(tp); + finishNode(); + } + else + { + LUAU_ASSERT(!"unknown type pack kind"); + finishNodeLabel(tp); + finishNode(); + } +} + +} // namespace + +std::string toDot(TypeId ty, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(ty, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypePackId tp, const ToDotOptions& opts) +{ + StateDot state{opts}; + + state.result = "digraph graphname {\n"; + state.visitChild(tp, 0); + state.result += "}"; + + return state.result; +} + +std::string toDot(TypeId ty) +{ + return toDot(ty, {}); +} + +std::string toDot(TypePackId tp) +{ + return toDot(tp, {}); +} + +void dumpDot(TypeId ty) +{ + printf("%s\n", toDot(ty).c_str()); +} + +void dumpDot(TypePackId tp) +{ + printf("%s\n", toDot(tp).c_str()); +} + +} // namespace Luau diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 735bfa503..6322096c4 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,7 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAG(LuauTypeAliasPacks) +LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) namespace Luau { @@ -59,11 +59,8 @@ struct FindCyclicTypes for (TypeId itp : ttv.instantiatedTypeParams) visitTypeVar(itp, *this, seen); - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv.instantiatedTypePackParams) - visitTypeVar(itp, *this, seen); - } + for (TypePackId itp : ttv.instantiatedTypePackParams) + visitTypeVar(itp, *this, seen); return exhaustive; } @@ -248,58 +245,45 @@ struct TypeVarStringifier void stringify(const std::vector& types, const std::vector& typePacks) { - if (types.size() == 0 && (!FFlag::LuauTypeAliasPacks || typePacks.size() == 0)) + if (types.size() == 0 && typePacks.size() == 0) return; - if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) + if (types.size() || typePacks.size()) state.emit("<"); - if (FFlag::LuauTypeAliasPacks) - { - bool first = true; - - for (TypeId ty : types) - { - if (!first) - state.emit(", "); - first = false; + bool first = true; - stringify(ty); - } + for (TypeId ty : types) + { + if (!first) + state.emit(", "); + first = false; - bool singleTp = typePacks.size() == 1; + stringify(ty); + } - for (TypePackId tp : typePacks) - { - if (isEmpty(tp) && singleTp) - continue; + bool singleTp = typePacks.size() == 1; - if (!first) - state.emit(", "); - else - first = false; + for (TypePackId tp : typePacks) + { + if (isEmpty(tp) && singleTp) + continue; - if (!singleTp) - state.emit("("); + if (!first) + state.emit(", "); + else + first = false; - stringify(tp); + if (!singleTp) + state.emit("("); - if (!singleTp) - state.emit(")"); - } - } - else - { - for (size_t i = 0; i < types.size(); ++i) - { - if (i > 0) - state.emit(", "); + stringify(tp); - stringify(types[i]); - } + if (!singleTp) + state.emit(")"); } - if (types.size() || (FFlag::LuauTypeAliasPacks && typePacks.size())) + if (types.size() || typePacks.size()) state.emit(">"); } @@ -767,12 +751,23 @@ struct TypePackStringifier else state.emit(", "); - LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); - - if (!elemNames.empty() && elemNames[elemIndex]) + if (FFlag::LuauFunctionArgumentNameSize) { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); + if (elemIndex < elemNames.size() && elemNames[elemIndex]) + { + state.emit(elemNames[elemIndex]->name); + state.emit(": "); + } + } + else + { + LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); + + if (!elemNames.empty() && elemNames[elemIndex]) + { + state.emit(elemNames[elemIndex]->name); + state.emit(": "); + } } elemIndex++; @@ -929,38 +924,7 @@ ToStringResult toStringDetailed(TypeId ty, const ToStringOptions& opts) result.name += ttv->name ? *ttv->name : *ttv->syntheticName; - if (FFlag::LuauTypeAliasPacks) - { - tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); - } - else - { - if (ttv->instantiatedTypeParams.empty() && (!FFlag::LuauTypeAliasPacks || ttv->instantiatedTypePackParams.empty())) - return result; - - result.name += "<"; - - bool first = true; - for (TypeId ty : ttv->instantiatedTypeParams) - { - if (!first) - result.name += ", "; - else - first = false; - - tvs.stringify(ty); - } - - if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) - { - result.truncated = true; - result.name += "... "; - } - else - { - result.name += ">"; - } - } + tvs.stringify(ttv->instantiatedTypeParams, ttv->instantiatedTypePackParams); return result; } @@ -1161,17 +1125,37 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV s += ", "; first = false; - // argNames is guaranteed to be equal to argTypes iff argNames is not empty. - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (!ftv.argNames.empty()) - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - s += toString_(*argPackIter); + if (FFlag::LuauFunctionArgumentNameSize) + { + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + ++argNameIter; + } + else + { + s += "_: "; + } + } + else + { + // argNames is guaranteed to be equal to argTypes iff argNames is not empty. + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (!ftv.argNames.empty()) + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + } + s += toString_(*argPackIter); ++argPackIter; - if (!ftv.argNames.empty()) + + if (!FFlag::LuauFunctionArgumentNameSize) { - LUAU_ASSERT(argNameIter != ftv.argNames.end()); - ++argNameIter; + if (!ftv.argNames.empty()) + { + LUAU_ASSERT(argNameIter != ftv.argNames.end()); + ++argNameIter; + } } } diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 6627fbe36..8e13ea5be 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,8 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - namespace { bool isIdentifierStartChar(char c) @@ -787,7 +785,7 @@ struct Printer writer.keyword("type"); writer.identifier(a->name.value); - if (a->generics.size > 0 || (FFlag::LuauTypeAliasPacks && a->genericPacks.size > 0)) + if (a->generics.size > 0 || a->genericPacks.size > 0) { writer.symbol("<"); CommaSeparatorInserter comma(writer); @@ -798,14 +796,11 @@ struct Printer writer.identifier(o.value); } - if (FFlag::LuauTypeAliasPacks) + for (auto o : a->genericPacks) { - for (auto o : a->genericPacks) - { - comma(); - writer.identifier(o.value); - writer.symbol("..."); - } + comma(); + writer.identifier(o.value); + writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 383bb050d..f6a61581e 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauShareTxnSeen, false) - namespace Luau { @@ -36,11 +34,8 @@ void TxnLog::rollback() for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) std::swap(it->first->boundTo, it->second); - if (FFlag::LuauShareTxnSeen) - { - LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); - sharedSeen->resize(originalSeenSize); - } + LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); + sharedSeen->resize(originalSeenSize); } void TxnLog::concat(TxnLog rhs) @@ -53,45 +48,25 @@ void TxnLog::concat(TxnLog rhs) tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); rhs.tableChanges.clear(); - - if (!FFlag::LuauShareTxnSeen) - { - ownedSeen.swap(rhs.ownedSeen); - rhs.ownedSeen.clear(); - } } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); - else - return (ownedSeen.end() != std::find(ownedSeen.begin(), ownedSeen.end(), sortedPair)); + return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); } void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - sharedSeen->push_back(sortedPair); - else - ownedSeen.push_back(sortedPair); + sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - if (FFlag::LuauShareTxnSeen) - { - LUAU_ASSERT(sortedPair == sharedSeen->back()); - sharedSeen->pop_back(); - } - else - { - LUAU_ASSERT(sortedPair == ownedSeen.back()); - ownedSeen.pop_back(); - } + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); } } // namespace Luau diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index af6d2543d..9e61c7924 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAG(LuauTypeAliasPacks) - static char* allocateString(Luau::Allocator& allocator, std::string_view contents) { char* result = (char*)allocator.allocate(contents.size() + 1); @@ -131,12 +129,9 @@ class TypeRehydrationVisitor parameters.data[i] = {Luau::visit(*this, ttv.instantiatedTypeParams[i]->ty), {}}; } - if (FFlag::LuauTypeAliasPacks) + for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) { - for (size_t i = 0; i < ttv.instantiatedTypePackParams.size(); ++i) - { - parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; - } + parameters.data[i] = {{}, rehydrate(ttv.instantiatedTypePackParams[i])}; } return allocator->alloc(Location(), std::nullopt, AstName(ttv.name->c_str()), parameters.size != 0, parameters); @@ -250,20 +245,7 @@ class TypeRehydrationVisitor AstTypePack* argTailAnnotation = nullptr; if (argTail) - { - if (FFlag::LuauTypeAliasPacks) - { - argTailAnnotation = rehydrate(*argTail); - } - else - { - TypePackId tail = *argTail; - if (const VariadicTypePack* vtp = get(tail)) - { - argTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } - } + argTailAnnotation = rehydrate(*argTail); AstArray> argNames; argNames.size = ftv.argNames.size(); @@ -292,20 +274,7 @@ class TypeRehydrationVisitor AstTypePack* retTailAnnotation = nullptr; if (retTail) - { - if (FFlag::LuauTypeAliasPacks) - { - retTailAnnotation = rehydrate(*retTail); - } - else - { - TypePackId tail = *retTail; - if (const VariadicTypePack* vtp = get(tail)) - { - retTailAnnotation = allocator->alloc(Location(), Luau::visit(*this, vtp->ty->ty)); - } - } - } + retTailAnnotation = rehydrate(*retTail); return allocator->alloc( Location(), generics, genericPacks, AstTypeList{argTypes, argTailAnnotation}, argNames, AstTypeList{returnTypes, retTailAnnotation}); @@ -518,18 +487,7 @@ class TypeAttacher : public AstVisitor const auto& [v, tail] = flatten(ret); if (tail) - { - if (FFlag::LuauTypeAliasPacks) - { - variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); - } - else - { - TypePackId tailPack = *tail; - if (const VariadicTypePack* vtp = get(tailPack)) - variadicAnnotation = allocator->alloc(Location(), typeAst(vtp->ty)); - } - } + variadicAnnotation = TypeRehydrationVisitor(allocator, &syntheticNames).rehydrate(*tail); fn->returnAnnotation = AstTypeList{typeAstPack(ret), variadicAnnotation}; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index b2ae94c72..617bf482c 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -23,22 +23,20 @@ LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) -LUAU_FASTFLAGVARIABLE(LuauClassPropertyAccessAsString, false) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauTraceRequireLookupChild) LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) -LUAU_FASTFLAG(LuauSubstitutionDontReplaceIgnoredTypes) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) -LUAU_FASTFLAG(LuauNewRequireTrace2) -LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) +LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) +LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) +LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) namespace Luau { @@ -562,12 +560,6 @@ ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location return canUnify_(left, right, location); } -ErrorVec TypeChecker::canUnify(const std::vector>& seen, TypeId superTy, TypeId subTy, const Location& location) -{ - Unifier state = mkUnifier(seen, location); - return state.canUnify(superTy, subTy); -} - template ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) { @@ -1152,61 +1144,20 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Location location = scope->typeAliasLocations[name]; reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); - if (FFlag::LuauTypeAliasPacks) - bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; - else - bindingsMap[name] = TypeFun{binding->typeParams, errorRecoveryType(anyType)}; + bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; } else { ScopePtr aliasScope = FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); - if (FFlag::LuauTypeAliasPacks) - { - auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); - TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; - } - else - { - std::vector generics; - for (AstName generic : typealias.generics) - { - Name n = generic.value; - - // These generics are the only thing that will ever be added to aliasScope, so we can be certain that - // a collision can only occur when two generic typevars have the same name. - if (aliasScope->privateTypeBindings.end() != aliasScope->privateTypeBindings.find(n)) - { - // TODO(jhuelsman): report the exact span of the generic type parameter whose name is a duplicate. - reportError(TypeError{typealias.location, DuplicateGenericParameter{n}}); - } - - TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) - { - TypeId& cached = scope->typeAliasTypeParameters[n]; - if (!cached) - cached = addType(GenericTypeVar{aliasScope->level, n}); - g = cached; - } - else - g = addType(GenericTypeVar{aliasScope->level, n}); - generics.push_back(g); - aliasScope->privateTypeBindings[n] = TypeFun{{}, g}; - } - - TypeId ty = freshType(aliasScope); - FreeTypeVar* ftv = getMutable(ty); - LUAU_ASSERT(ftv); - ftv->forwardedTypeAlias = true; - bindingsMap[name] = {std::move(generics), ty}; - } + TypeId ty = freshType(aliasScope); + FreeTypeVar* ftv = getMutable(ty); + LUAU_ASSERT(ftv); + ftv->forwardedTypeAlias = true; + bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; } } else @@ -1223,14 +1174,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; } - if (FFlag::LuauTypeAliasPacks) + for (TypePackId tp : binding->typePackParams) { - for (TypePackId tp : binding->typePackParams) - { - auto generic = get(tp); - LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = tp; - } + auto generic = get(tp); + LUAU_ASSERT(generic); + aliasScope->privateTypePackBindings[generic->name] = tp; } TypeId ty = resolveType(aliasScope, *typealias.type); @@ -1241,19 +1189,16 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // Copy can be skipped if this is an identical alias if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || - (FFlag::LuauTypeAliasPacks && ttv->instantiatedTypePackParams != binding->typePackParams)) + ttv->instantiatedTypePackParams != binding->typePackParams) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; - clone.name = name; clone.instantiatedTypeParams = binding->typeParams; - - if (FFlag::LuauTypeAliasPacks) - clone.instantiatedTypePackParams = binding->typePackParams; + clone.instantiatedTypePackParams = binding->typePackParams; ty = addType(std::move(clone)); } @@ -1262,9 +1207,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ttv->name = name; ttv->instantiatedTypeParams = binding->typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = binding->typePackParams; + ttv->instantiatedTypePackParams = binding->typePackParams; } } else if (auto mtv = getMutable(follow(ty))) @@ -1289,7 +1232,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar } // We don't have generic classes, so this assertion _should_ never be hit. - LUAU_ASSERT(lookupType->typeParams.size() == 0 && (!FFlag::LuauTypeAliasPacks || lookupType->typePackParams.size() == 0)); + LUAU_ASSERT(lookupType->typeParams.size() == 0 && lookupType->typePackParams.size() == 0); superTy = lookupType->type; if (!get(follow(*superTy))) @@ -1851,6 +1794,24 @@ TypeId TypeChecker::checkExprTable( if (isNonstrictMode() && !getTableType(exprType) && !get(exprType)) exprType = anyType; + if (FFlag::LuauPropertiesGetExpectedType && expectedTable) + { + auto it = expectedTable->props.find(key->value.data); + if (it != expectedTable->props.end()) + { + Property expectedProp = it->second; + ErrorVec errors = tryUnify(expectedProp.type, exprType, k->location); + if (errors.empty()) + exprType = expectedProp.type; + } + else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) + { + ErrorVec errors = tryUnify(expectedTable->indexer->indexResultType, exprType, k->location); + if (errors.empty()) + exprType = expectedTable->indexer->indexResultType; + } + } + props[key->value.data] = {exprType, /* deprecated */ false, {}, k->location}; } else @@ -3744,17 +3705,29 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; + std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; if (i == lastIndex && (expr->is() || expr->is())) { auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); + if (FFlag::LuauTailArgumentTypeInfo) + { + if (std::optional firstTy = first(typePack)) + { + if (!currentModule->astTypes.find(expr)) + currentModule->astTypes[expr] = follow(*firstTy); + } + + if (expectedType) + currentModule->astExpectedTypes[expr] = *expectedType; + } + tp->tail = typePack; } else { - std::optional expectedType = i < expectedTypes.size() ? expectedTypes[i] : std::nullopt; auto [type, exprPredicates] = checkExpr(scope, *expr, expectedType); insert(exprPredicates); @@ -3797,7 +3770,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module LUAU_TIMETRACE_SCOPE("TypeChecker::checkRequire", "TypeChecker"); LUAU_TIMETRACE_ARGUMENT("moduleInfo", moduleInfo.name.c_str()); - if (FFlag::LuauNewRequireTrace2 && moduleInfo.name.empty()) + if (moduleInfo.name.empty()) { if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) { @@ -3814,7 +3787,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module // There are two reasons why we might fail to find the module: // either the file does not exist or there's a cycle. If there's a cycle // we will already have reported the error. - if (!resolver->moduleExists(moduleInfo.name) && (FFlag::LuauTraceRequireLookupChild ? !moduleInfo.optional : true)) + if (!resolver->moduleExists(moduleInfo.name) && !moduleInfo.optional) { std::string reportedModulePath = resolver->getHumanReadableModuleName(moduleInfo.name); reportError(TypeError{location, UnknownRequire{reportedModulePath}}); @@ -3830,7 +3803,12 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - std::optional moduleType = first(module->getModuleScope()->returnType); + TypePackId modulePack = module->getModuleScope()->returnType; + + if (FFlag::LuauModuleRequireErrorPack && get(modulePack)) + return errorRecoveryType(scope); + + std::optional moduleType = first(modulePack); if (!moduleType) { std::string humanReadableName = resolver->getHumanReadableModuleName(moduleInfo.name); @@ -3840,7 +3818,8 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module SeenTypes seenTypes; SeenTypePacks seenTypePacks; - return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks); + CloneState cloneState; + return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks, cloneState); } void TypeChecker::tablify(TypeId type) @@ -4326,11 +4305,6 @@ Unifier TypeChecker::mkUnifier(const Location& location) return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; } -Unifier TypeChecker::mkUnifier(const std::vector>& seen, const Location& location) -{ - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, seen, location, Variance::Covariant, unifierState}; -} - TypeId TypeChecker::freshType(const ScopePtr& scope) { return freshType(scope->level); @@ -4477,117 +4451,82 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return errorRecoveryType(scope); } - if (lit->parameters.size == 0 && tf->typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf->typePackParams.empty())) - { + if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - } - else if (!FFlag::LuauTypeAliasPacks && lit->parameters.size != tf->typeParams.size()) + + if (!lit->hasParameterList && !tf->typePackParams.empty()) { - reportError(TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, lit->parameters.size, 0}}); + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); if (!FFlag::LuauErrorRecoveryType) return errorRecoveryType(scope); } - if (FFlag::LuauTypeAliasPacks) - { - if (!lit->hasParameterList && !tf->typePackParams.empty()) - { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); - } - - std::vector typeParams; - std::vector extraTypes; - std::vector typePackParams; + std::vector typeParams; + std::vector extraTypes; + std::vector typePackParams; - for (size_t i = 0; i < lit->parameters.size; ++i) + for (size_t i = 0; i < lit->parameters.size; ++i) + { + if (AstType* type = lit->parameters.data[i].type) { - if (AstType* type = lit->parameters.data[i].type) - { - TypeId ty = resolveType(scope, *type); + TypeId ty = resolveType(scope, *type); - if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) - typeParams.push_back(ty); - else if (typePackParams.empty()) - extraTypes.push_back(ty); - else - reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); - } - else if (AstTypePack* typePack = lit->parameters.data[i].typePack) - { - TypePackId tp = resolveTypePack(scope, *typePack); - - // If we have collected an implicit type pack, materialize it - if (typePackParams.empty() && !extraTypes.empty()) - typePackParams.push_back(addTypePack(extraTypes)); - - // If we need more regular types, we can use single element type packs to fill those in - if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) - typeParams.push_back(*first(tp)); - else - typePackParams.push_back(tp); - } + if (typeParams.size() < tf->typeParams.size() || tf->typePackParams.empty()) + typeParams.push_back(ty); + else if (typePackParams.empty()) + extraTypes.push_back(ty); + else + reportError(TypeError{annotation.location, GenericError{"Type parameters must come before type pack parameters"}}); } - - // If we still haven't meterialized an implicit type pack, do it now - if (typePackParams.empty() && !extraTypes.empty()) - typePackParams.push_back(addTypePack(extraTypes)); - - // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack - if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) - typePackParams.push_back(addTypePack({})); - - if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + else if (AstTypePack* typePack = lit->parameters.data[i].typePack) { - reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + TypePackId tp = resolveTypePack(scope, *typePack); - if (FFlag::LuauErrorRecoveryType) - { - // Pad the types out with error recovery types - while (typeParams.size() < tf->typeParams.size()) - typeParams.push_back(errorRecoveryType(scope)); - while (typePackParams.size() < tf->typePackParams.size()) - typePackParams.push_back(errorRecoveryTypePack(scope)); - } + // If we have collected an implicit type pack, materialize it + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); + + // If we need more regular types, we can use single element type packs to fill those in + if (typeParams.size() < tf->typeParams.size() && size(tp) == 1 && finite(tp) && first(tp)) + typeParams.push_back(*first(tp)); else - return errorRecoveryType(scope); + typePackParams.push_back(tp); } + } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) - { - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - return tf->type; - } + // If we still haven't meterialized an implicit type pack, do it now + if (typePackParams.empty() && !extraTypes.empty()) + typePackParams.push_back(addTypePack(extraTypes)); - return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); - } - else - { - std::vector typeParams; + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack + if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) + typePackParams.push_back(addTypePack({})); - for (const auto& param : lit->parameters) - typeParams.push_back(resolveType(scope, *param.type)); + if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) + { + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); if (FFlag::LuauErrorRecoveryType) { - // If there aren't enough type parameters, pad them out with error recovery types - // (we've already reported the error) - while (typeParams.size() < lit->parameters.size) + // Pad the types out with error recovery types + while (typeParams.size() < tf->typeParams.size()) typeParams.push_back(errorRecoveryType(scope)); + while (typePackParams.size() < tf->typePackParams.size()) + typePackParams.push_back(errorRecoveryTypePack(scope)); } + else + return errorRecoveryType(scope); + } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams) - { - // If the generic parameters and the type arguments are the same, we are about to - // perform an identity substitution, which we can just short-circuit. - return tf->type; - } - - return instantiateTypeFun(scope, *tf, typeParams, {}, annotation.location); + if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + { + // If the generic parameters and the type arguments are the same, we are about to + // perform an identity substitution, which we can just short-circuit. + return tf->type; } + + return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); } else if (const auto& table = annotation.as()) { @@ -4757,7 +4696,7 @@ bool ApplyTypeFunction::isDirty(TypePackId tp) bool ApplyTypeFunction::ignoreChildren(TypeId ty) { - if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(ty)) + if (get(ty)) return true; else return false; @@ -4765,7 +4704,7 @@ bool ApplyTypeFunction::ignoreChildren(TypeId ty) bool ApplyTypeFunction::ignoreChildren(TypePackId tp) { - if (FFlag::LuauSubstitutionDontReplaceIgnoredTypes && get(tp)) + if (get(tp)) return true; else return false; @@ -4788,36 +4727,26 @@ TypePackId ApplyTypeFunction::clean(TypePackId tp) // Really this should just replace the arguments, // but for bug-compatibility with existing code, we replace // all generics by free type variables. - if (FFlag::LuauTypeAliasPacks) - { - TypePackId& arg = typePackArguments[tp]; - if (arg) - return arg; - else - return addTypePack(FreeTypePack{level}); - } + TypePackId& arg = typePackArguments[tp]; + if (arg) + return arg; else - { return addTypePack(FreeTypePack{level}); - } } TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, const std::vector& typeParams, const std::vector& typePackParams, const Location& location) { - if (tf.typeParams.empty() && (!FFlag::LuauTypeAliasPacks || tf.typePackParams.empty())) + if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; - if (FFlag::LuauTypeAliasPacks) - { - applyTypeFunction.typePackArguments.clear(); - for (size_t i = 0; i < tf.typePackParams.size(); ++i) - applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; - } + applyTypeFunction.typePackArguments.clear(); + for (size_t i = 0; i < tf.typePackParams.size(); ++i) + applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; @@ -4866,9 +4795,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (ttv) { ttv->instantiatedTypeParams = typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = typePackParams; + ttv->instantiatedTypePackParams = typePackParams; } } else @@ -4884,9 +4811,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } ttv->instantiatedTypeParams = typeParams; - - if (FFlag::LuauTypeAliasPacks) - ttv->instantiatedTypePackParams = typePackParams; + ttv->instantiatedTypePackParams = typePackParams; } } @@ -4914,7 +4839,7 @@ std::pair, std::vector> TypeChecker::createGener } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + if (FFlag::LuauRecursiveTypeParameterRestriction) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) @@ -4944,7 +4869,7 @@ std::pair, std::vector> TypeChecker::createGener } TypePackId g; - if (FFlag::LuauRecursiveTypeParameterRestriction && FFlag::LuauTypeAliasPacks) + if (FFlag::LuauRecursiveTypeParameterRestriction) { TypePackId& cached = scope->parent->typeAliasTypePackParameters[n]; if (!cached) @@ -5245,7 +5170,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); auto typeFun = globalScope->lookupType(typeguardP.kind); - if (!typeFun || !typeFun->typeParams.empty() || (FFlag::LuauTypeAliasPacks && !typeFun->typePackParams.empty())) + if (!typeFun || !typeFun->typeParams.empty() || !typeFun->typePackParams.empty()) return fail(UnknownSymbol{typeguardP.kind, UnknownSymbol::Type}); TypeId type = follow(typeFun->type); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 924bf082a..62715af53 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,7 +19,6 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) -LUAU_FASTFLAG(LuauTypeAliasPacks) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) @@ -739,369 +738,6 @@ void persist(TypePackId tp) } } -namespace -{ - -struct StateDot -{ - StateDot(ToDotOptions opts) - : opts(opts) - { - } - - ToDotOptions opts; - - std::unordered_set seenTy; - std::unordered_set seenTp; - std::unordered_map tyToIndex; - std::unordered_map tpToIndex; - int nextIndex = 1; - std::string result; - - bool canDuplicatePrimitive(TypeId ty); - - void visitChildren(TypeId ty, int index); - void visitChildren(TypePackId ty, int index); - - void visitChild(TypeId ty, int parentIndex, const char* linkName = nullptr); - void visitChild(TypePackId tp, int parentIndex, const char* linkName = nullptr); - - void startNode(int index); - void finishNode(); - - void startNodeLabel(); - void finishNodeLabel(TypeId ty); - void finishNodeLabel(TypePackId tp); -}; - -bool StateDot::canDuplicatePrimitive(TypeId ty) -{ - if (get(ty)) - return false; - - return get(ty) || get(ty); -} - -void StateDot::visitChild(TypeId ty, int parentIndex, const char* linkName) -{ - if (!tyToIndex.count(ty) || (opts.duplicatePrimitives && canDuplicatePrimitive(ty))) - tyToIndex[ty] = nextIndex++; - - int index = tyToIndex[ty]; - - if (parentIndex != 0) - { - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, index, linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, index); - } - - if (opts.duplicatePrimitives && canDuplicatePrimitive(ty)) - { - if (get(ty)) - formatAppend(result, "n%d [label=\"%s\"];\n", index, toStringDetailed(ty, {}).name.c_str()); - else if (get(ty)) - formatAppend(result, "n%d [label=\"any\"];\n", index); - } - else - { - visitChildren(ty, index); - } -} - -void StateDot::visitChild(TypePackId tp, int parentIndex, const char* linkName) -{ - if (!tpToIndex.count(tp)) - tpToIndex[tp] = nextIndex++; - - if (linkName) - formatAppend(result, "n%d -> n%d [label=\"%s\"];\n", parentIndex, tpToIndex[tp], linkName); - else - formatAppend(result, "n%d -> n%d;\n", parentIndex, tpToIndex[tp]); - - visitChildren(tp, tpToIndex[tp]); -} - -void StateDot::startNode(int index) -{ - formatAppend(result, "n%d [", index); -} - -void StateDot::finishNode() -{ - formatAppend(result, "];\n"); -} - -void StateDot::startNodeLabel() -{ - formatAppend(result, "label=\""); -} - -void StateDot::finishNodeLabel(TypeId ty) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", ty); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::finishNodeLabel(TypePackId tp) -{ - if (opts.showPointers) - formatAppend(result, "\n0x%p", tp); - // additional common attributes can be added here as well - result += "\""; -} - -void StateDot::visitChildren(TypeId ty, int index) -{ - if (seenTy.count(ty)) - return; - seenTy.insert(ty); - - startNode(index); - startNodeLabel(); - - if (const BoundTypeVar* btv = get(ty)) - { - formatAppend(result, "BoundTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(btv->boundTo, index); - } - else if (const FunctionTypeVar* ftv = get(ty)) - { - formatAppend(result, "FunctionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(ftv->argTypes, index, "arg"); - visitChild(ftv->retType, index, "ret"); - } - else if (const TableTypeVar* ttv = get(ty)) - { - if (ttv->name) - formatAppend(result, "TableTypeVar %s", ttv->name->c_str()); - else if (ttv->syntheticName) - formatAppend(result, "TableTypeVar %s", ttv->syntheticName->c_str()); - else - formatAppend(result, "TableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - if (ttv->boundTo) - return visitChild(*ttv->boundTo, index, "boundTo"); - - for (const auto& [name, prop] : ttv->props) - visitChild(prop.type, index, name.c_str()); - if (ttv->indexer) - { - visitChild(ttv->indexer->indexType, index, "[index]"); - visitChild(ttv->indexer->indexResultType, index, "[value]"); - } - for (TypeId itp : ttv->instantiatedTypeParams) - visitChild(itp, index, "typeParam"); - - if (FFlag::LuauTypeAliasPacks) - { - for (TypePackId itp : ttv->instantiatedTypePackParams) - visitChild(itp, index, "typePackParam"); - } - } - else if (const MetatableTypeVar* mtv = get(ty)) - { - formatAppend(result, "MetatableTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - visitChild(mtv->table, index, "table"); - visitChild(mtv->metatable, index, "metatable"); - } - else if (const UnionTypeVar* utv = get(ty)) - { - formatAppend(result, "UnionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId opt : utv->options) - visitChild(opt, index); - } - else if (const IntersectionTypeVar* itv = get(ty)) - { - formatAppend(result, "IntersectionTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - - for (TypeId part : itv->parts) - visitChild(part, index); - } - else if (const GenericTypeVar* gtv = get(ty)) - { - if (gtv->explicitName) - formatAppend(result, "GenericTypeVar %s", gtv->name.c_str()); - else - formatAppend(result, "GenericTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const FreeTypeVar* ftv = get(ty)) - { - formatAppend(result, "FreeTypeVar %d", ftv->index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "AnyTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "PrimitiveTypeVar %s", toStringDetailed(ty, {}).name.c_str()); - finishNodeLabel(ty); - finishNode(); - } - else if (get(ty)) - { - formatAppend(result, "ErrorTypeVar %d", index); - finishNodeLabel(ty); - finishNode(); - } - else if (const ClassTypeVar* ctv = get(ty)) - { - formatAppend(result, "ClassTypeVar %s", ctv->name.c_str()); - finishNodeLabel(ty); - finishNode(); - - for (const auto& [name, prop] : ctv->props) - visitChild(prop.type, index, name.c_str()); - - if (ctv->parent) - visitChild(*ctv->parent, index, "[parent]"); - - if (ctv->metatable) - visitChild(*ctv->metatable, index, "[metatable]"); - } - else - { - LUAU_ASSERT(!"unknown type kind"); - finishNodeLabel(ty); - finishNode(); - } -} - -void StateDot::visitChildren(TypePackId tp, int index) -{ - if (seenTp.count(tp)) - return; - seenTp.insert(tp); - - startNode(index); - startNodeLabel(); - - if (const BoundTypePack* btp = get(tp)) - { - formatAppend(result, "BoundTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(btp->boundTo, index); - } - else if (const TypePack* tpp = get(tp)) - { - formatAppend(result, "TypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - for (TypeId tv : tpp->head) - visitChild(tv, index); - if (tpp->tail) - visitChild(*tpp->tail, index, "tail"); - } - else if (const VariadicTypePack* vtp = get(tp)) - { - formatAppend(result, "VariadicTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - - visitChild(vtp->ty, index); - } - else if (const FreeTypePack* ftp = get(tp)) - { - formatAppend(result, "FreeTypePack %d", ftp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (const GenericTypePack* gtp = get(tp)) - { - if (gtp->explicitName) - formatAppend(result, "GenericTypePack %s", gtp->name.c_str()); - else - formatAppend(result, "GenericTypePack %d", gtp->index); - finishNodeLabel(tp); - finishNode(); - } - else if (get(tp)) - { - formatAppend(result, "ErrorTypePack %d", index); - finishNodeLabel(tp); - finishNode(); - } - else - { - LUAU_ASSERT(!"unknown type pack kind"); - finishNodeLabel(tp); - finishNode(); - } -} - -} // namespace - -std::string toDot(TypeId ty, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(ty, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypePackId tp, const ToDotOptions& opts) -{ - StateDot state{opts}; - - state.result = "digraph graphname {\n"; - state.visitChild(tp, 0); - state.result += "}"; - - return state.result; -} - -std::string toDot(TypeId ty) -{ - return toDot(ty, {}); -} - -std::string toDot(TypePackId tp) -{ - return toDot(tp, {}); -} - -void dumpDot(TypeId ty) -{ - printf("%s\n", toDot(ty).c_str()); -} - -void dumpDot(TypePackId tp) -{ - printf("%s\n", toDot(tp).c_str()); -} - const TypeLevel* getLevel(TypeId ty) { ty = follow(ty); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index e1a52be4e..d0b188374 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,9 +18,6 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauTypecheckOpts, false) -LUAU_FASTFLAG(LuauShareTxnSeen); -LUAU_FASTFLAGVARIABLE(LuauCacheUnifyTableResults, false) LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) @@ -136,38 +133,19 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati , globalScope(std::move(globalScope)) , location(location) , variance(variance) - , counters(&countersData) - , counters_DEPRECATED(std::make_shared()) - , sharedState(sharedState) -{ - LUAU_ASSERT(sharedState.iceHandler); -} - -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const std::vector>& ownedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) - : types(types) - , mode(mode) - , globalScope(std::move(globalScope)) - , log(ownedSeen) - , location(location) - , variance(variance) - , counters(counters ? counters : &countersData) - , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , sharedState(sharedState) { LUAU_ASSERT(sharedState.iceHandler); } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState, const std::shared_ptr& counters_DEPRECATED, UnifierCounters* counters) + Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) , globalScope(std::move(globalScope)) , log(sharedSeen) , location(location) , variance(variance) - , counters(counters ? counters : &countersData) - , counters_DEPRECATED(counters_DEPRECATED ? counters_DEPRECATED : std::make_shared()) , sharedState(sharedState) { LUAU_ASSERT(sharedState.iceHandler); @@ -175,26 +153,18 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector< void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - if (FFlag::LuauTypecheckOpts) - counters->iterationCount = 0; - else - counters_DEPRECATED->iterationCount = 0; + sharedState.counters.iterationCount = 0; tryUnify_(superTy, subTy, isFunctionCall, isIntersection); } void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - if (FFlag::LuauTypecheckOpts) - ++counters->iterationCount; - else - ++counters_DEPRECATED->iterationCount; + ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && - FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -302,7 +272,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(subTy, superTy); - bool cacheEnabled = FFlag::LuauCacheUnifyTableResults && !isFunctionCall && !isIntersection; + bool cacheEnabled = !isFunctionCall && !isIntersection; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before @@ -563,8 +533,6 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool void Unifier::cacheResult(TypeId superTy, TypeId subTy) { - LUAU_ASSERT(FFlag::LuauCacheUnifyTableResults); - bool* superTyInfo = sharedState.skipCacheForType.find(superTy); if (superTyInfo && *superTyInfo) @@ -686,10 +654,7 @@ ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunction void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - if (FFlag::LuauTypecheckOpts) - counters->iterationCount = 0; - else - counters_DEPRECATED->iterationCount = 0; + sharedState.counters.iterationCount = 0; tryUnify_(superTp, subTp, isFunctionCall); } @@ -700,16 +665,11 @@ void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall */ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - if (FFlag::LuauTypecheckOpts) - ++counters->iterationCount; - else - ++counters_DEPRECATED->iterationCount; + ++sharedState.counters.iterationCount; - if (FInt::LuauTypeInferIterationLimit > 0 && - FInt::LuauTypeInferIterationLimit < (FFlag::LuauTypecheckOpts ? counters->iterationCount : counters_DEPRECATED->iterationCount)) + if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { errors.push_back(TypeError{location, UnificationTooComplex{}}); return; @@ -1727,39 +1687,8 @@ void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& sub tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); } -static void queueTypePack_DEPRECATED( - std::vector& queue, std::unordered_set& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOpts); - - while (true) - { - a = follow(a); - - if (seenTypePacks.count(a)) - break; - seenTypePacks.insert(a); - - if (get(a)) - { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } - } -} - static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) { - LUAU_ASSERT(FFlag::LuauTypecheckOpts); - while (true) { a = follow(a); @@ -1837,66 +1766,9 @@ void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool rever } } -static void tryUnifyWithAny_DEPRECATED( - std::vector& queue, Unifier& state, std::unordered_set& seenTypePacks, TypeId anyType, TypePackId anyTypePack) -{ - LUAU_ASSERT(!FFlag::LuauTypecheckOpts); - - std::unordered_set seen; - - while (!queue.empty()) - { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.count(ty)) - continue; - seen.insert(ty); - - if (get(ty)) - { - state.log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack_DEPRECATED(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); - - if (table->indexer) - { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); - } - } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); - else - { - } // Primitives, any, errors, and generics are left untouched. - } -} - static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, TypeId anyType, TypePackId anyTypePack) { - LUAU_ASSERT(FFlag::LuauTypecheckOpts); - while (!queue.empty()) { TypeId ty = follow(queue.back()); @@ -1949,43 +1821,20 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) { LUAU_ASSERT(get(any) || get(any)); - if (FFlag::LuauTypecheckOpts) - { - // These types are not visited in general loop below - if (get(ty) || get(ty) || get(ty)) - return; - } + // These types are not visited in general loop below + if (get(ty) || get(ty) || get(ty)) + return; const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - if (FFlag::LuauTypecheckOpts) - { - std::vector queue = {ty}; - - if (FFlag::LuauCacheUnifyTableResults) - { - sharedState.tempSeenTy.clear(); - sharedState.tempSeenTp.clear(); - - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); - } - else - { - tempSeenTy_DEPRECATED.clear(); - tempSeenTp_DEPRECATED.clear(); + std::vector queue = {ty}; - Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, singletonTypes.anyType, anyTP); - } - } - else - { - std::unordered_set seenTypePacks; - std::vector queue = {ty}; + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, singletonTypes.anyType, anyTP); - } + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) @@ -1994,38 +1843,14 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) const TypeId anyTy = singletonTypes.errorRecoveryType(); - if (FFlag::LuauTypecheckOpts) - { - std::vector queue; - - if (FFlag::LuauCacheUnifyTableResults) - { - sharedState.tempSeenTy.clear(); - sharedState.tempSeenTp.clear(); - - queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); - } - else - { - tempSeenTy_DEPRECATED.clear(); - tempSeenTp_DEPRECATED.clear(); - - queueTypePack(queue, tempSeenTp_DEPRECATED, *this, ty, any); + std::vector queue; - Luau::tryUnifyWithAny(queue, *this, tempSeenTy_DEPRECATED, tempSeenTp_DEPRECATED, anyTy, any); - } - } - else - { - std::unordered_set seenTypePacks; - std::vector queue; + sharedState.tempSeenTy.clear(); + sharedState.tempSeenTp.clear(); - queueTypePack_DEPRECATED(queue, seenTypePacks, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); - Luau::tryUnifyWithAny_DEPRECATED(queue, *this, seenTypePacks, anyTy, any); - } + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -2035,46 +1860,22 @@ std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, N void Unifier::occursCheck(TypeId needle, TypeId haystack) { - std::unordered_set seen_DEPRECATED; + sharedState.tempSeenTy.clear(); - if (FFlag::LuauCacheUnifyTableResults) - { - if (FFlag::LuauTypecheckOpts) - sharedState.tempSeenTy.clear(); - - return occursCheck(seen_DEPRECATED, sharedState.tempSeenTy, needle, haystack); - } - else - { - if (FFlag::LuauTypecheckOpts) - tempSeenTy_DEPRECATED.clear(); - - return occursCheck(seen_DEPRECATED, tempSeenTy_DEPRECATED, needle, haystack); - } + return occursCheck(sharedState.tempSeenTy, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypeId needle, TypeId haystack) +void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId haystack) { - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); needle = follow(needle); haystack = follow(haystack); - if (FFlag::LuauTypecheckOpts) - { - if (seen.find(haystack)) - return; - - seen.insert(haystack); - } - else - { - if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen_DEPRECATED.insert(haystack); - } + seen.insert(haystack); if (get(needle)) return; @@ -2091,7 +1892,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash } auto check = [&](TypeId tv) { - occursCheck(seen_DEPRECATED, seen, needle, tv); + occursCheck(seen, needle, tv); }; if (get(haystack)) @@ -2121,43 +1922,20 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHash void Unifier::occursCheck(TypePackId needle, TypePackId haystack) { - std::unordered_set seen_DEPRECATED; - - if (FFlag::LuauCacheUnifyTableResults) - { - if (FFlag::LuauTypecheckOpts) - sharedState.tempSeenTp.clear(); - - return occursCheck(seen_DEPRECATED, sharedState.tempSeenTp, needle, haystack); - } - else - { - if (FFlag::LuauTypecheckOpts) - tempSeenTp_DEPRECATED.clear(); + sharedState.tempSeenTp.clear(); - return occursCheck(seen_DEPRECATED, tempSeenTp_DEPRECATED, needle, haystack); - } + return occursCheck(sharedState.tempSeenTp, needle, haystack); } -void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, DenseHashSet& seen, TypePackId needle, TypePackId haystack) +void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { needle = follow(needle); haystack = follow(haystack); - if (FFlag::LuauTypecheckOpts) - { - if (seen.find(haystack)) - return; - - seen.insert(haystack); - } - else - { - if (seen_DEPRECATED.end() != seen_DEPRECATED.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen_DEPRECATED.insert(haystack); - } + seen.insert(haystack); if (get(needle)) return; @@ -2165,8 +1943,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense if (!get(needle)) ice("Expected needle pack to be free"); - RecursionLimiter _ra( - FFlag::LuauTypecheckOpts ? &counters->recursionCount : &counters_DEPRECATED->recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); while (!get(haystack)) { @@ -2186,8 +1963,8 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense { if (auto f = get(follow(ty))) { - occursCheck(seen_DEPRECATED, seen, needle, f->argTypes); - occursCheck(seen_DEPRECATED, seen, needle, f->retType); + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); } } } @@ -2204,10 +1981,7 @@ void Unifier::occursCheck(std::unordered_set& seen_DEPRECATED, Dense Unifier Unifier::makeChildUnifier() { - if (FFlag::LuauShareTxnSeen) - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; - else - return Unifier{types, mode, globalScope, log.ownedSeen, location, variance, sharedState, counters_DEPRECATED, counters}; + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index bc63e37dc..3d0d5b7e6 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -13,8 +13,6 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAliasPacks, false) -LUAU_FASTFLAGVARIABLE(LuauParseTypePackTypeParameters, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) @@ -782,8 +780,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) AstType* type = parseTypeAnnotation(); - return allocator.alloc( - Location(start, type->location), name->name, generics, FFlag::LuauTypeAliasPacks ? genericPacks : AstArray{}, type, exported); + return allocator.alloc(Location(start, type->location), name->name, generics, genericPacks, type, exported); } AstDeclaredClassProp Parser::parseDeclaredClassMethod() @@ -1602,30 +1599,18 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) return {allocator.alloc(Location(begin, end), expr), {}}; } - if (FFlag::LuauParseTypePackTypeParameters) - { - bool hasParameters = false; - AstArray parameters{}; - - if (lexer.current().type == '<') - { - hasParameters = true; - parameters = parseTypeParams(); - } + bool hasParameters = false; + AstArray parameters{}; - Location end = lexer.previousLocation(); - - return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; - } - else + if (lexer.current().type == '<') { - AstArray generics = parseTypeParams(); + hasParameters = true; + parameters = parseTypeParams(); + } - Location end = lexer.previousLocation(); + Location end = lexer.previousLocation(); - // false in 'hasParameterList' as it is not used without FFlagLuauTypeAliasPacks - return {allocator.alloc(Location(begin, end), prefix, name.name, false, generics), {}}; - } + return {allocator.alloc(Location(begin, end), prefix, name.name, hasParameters, parameters), {}}; } else if (lexer.current().type == '{') { @@ -2414,37 +2399,24 @@ AstArray Parser::parseTypeParams() while (true) { - if (FFlag::LuauParseTypePackTypeParameters) + if (shouldParseTypePackAnnotation(lexer)) + { + auto typePack = parseTypePackAnnotation(); + + parameters.push_back({{}, typePack}); + } + else if (lexer.current().type == '(') { - if (shouldParseTypePackAnnotation(lexer)) - { - auto typePack = parseTypePackAnnotation(); - - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them - parameters.push_back({{}, typePack}); - } - else if (lexer.current().type == '(') - { - auto [type, typePack] = parseTypeOrPackAnnotation(); - - if (typePack) - { - if (FFlag::LuauTypeAliasPacks) // Type packs are recorded only is we can handle them - parameters.push_back({{}, typePack}); - } - else - { - parameters.push_back({type, {}}); - } - } - else if (lexer.current().type == '>' && parameters.empty()) - { - break; - } + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (typePack) + parameters.push_back({{}, typePack}); else - { - parameters.push_back({parseTypeAnnotation(), {}}); - } + parameters.push_back({type, {}}); + } + else if (lexer.current().type == '>' && parameters.empty()) + { + break; } else { diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index ebdd78966..9230d80d0 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -121,7 +121,7 @@ struct CliFileResolver : Luau::FileResolver if (Luau::AstExprConstantString* expr = node->as()) { Luau::ModuleName name = std::string(expr->value.data, expr->value.size) + ".luau"; - if (!moduleExists(name)) + if (!readFile(name)) { // fall back to .lua if a module with .luau doesn't exist name = std::string(expr->value.data, expr->value.size) + ".lua"; @@ -132,27 +132,6 @@ struct CliFileResolver : Luau::FileResolver return std::nullopt; } - - bool moduleExists(const Luau::ModuleName& name) const override - { - return !!readFile(name); - } - - - std::optional fromAstFragment(Luau::AstExpr* expr) const override - { - return std::nullopt; - } - - Luau::ModuleName concat(const Luau::ModuleName& lhs, std::string_view rhs) const override - { - return lhs + "/" + std::string(rhs); - } - - std::optional getParentModuleName(const Luau::ModuleName& name) const override - { - return std::nullopt; - } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index b29cd6f9c..2cdd0062f 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -198,11 +198,6 @@ static std::string runCode(lua_State* L, const std::string& source) error += "\nstack backtrace:\n"; error += lua_debugtrace(T); -#ifdef __EMSCRIPTEN__ - // nicer formatting for errors in web repl - error = "Error:" + error; -#endif - fprintf(stdout, "%s", error.c_str()); } @@ -210,39 +205,6 @@ static std::string runCode(lua_State* L, const std::string& source) return std::string(); } -#ifdef __EMSCRIPTEN__ -extern "C" -{ - const char* executeScript(const char* source) - { - // setup flags - for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) - if (strncmp(flag->name, "Luau", 4) == 0) - flag->value = true; - - // create new state - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - // setup state - setupState(L); - - // sandbox thread - luaL_sandboxthread(L); - - // static string for caching result (prevents dangling ptr on function exit) - static std::string result; - - // run code + collect error - result = runCode(L, source); - - return result.empty() ? NULL : result.c_str(); - } -} -#endif - -// Excluded from emscripten compilation to avoid -Wunused-function errors. -#ifndef __EMSCRIPTEN__ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; @@ -564,6 +526,5 @@ int main(int argc, char** argv) return failed; } } -#endif diff --git a/CLI/Web.cpp b/CLI/Web.cpp new file mode 100644 index 000000000..cf5c831e9 --- /dev/null +++ b/CLI/Web.cpp @@ -0,0 +1,106 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" +#include "luacode.h" + +#include "Luau/Common.h" + +#include + +#include + +static void setupState(lua_State* L) +{ + luaL_openlibs(L); + + luaL_sandbox(L); +} + +static std::string runCode(lua_State* L, const std::string& source) +{ + size_t bytecodeSize = 0; + char* bytecode = luau_compile(source.data(), source.length(), nullptr, &bytecodeSize); + int result = luau_load(L, "=stdin", bytecode, bytecodeSize, 0); + free(bytecode); + + if (result != 0) + { + size_t len; + const char* msg = lua_tolstring(L, -1, &len); + + std::string error(msg, len); + lua_pop(L, 1); + + return error; + } + + lua_State* T = lua_newthread(L); + + lua_pushvalue(L, -2); + lua_remove(L, -3); + lua_xmove(L, T, 1); + + int status = lua_resume(T, NULL, 0); + + if (status == 0) + { + int n = lua_gettop(T); + + if (n) + { + luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); + lua_getglobal(T, "print"); + lua_insert(T, 1); + lua_pcall(T, n, 0, 0); + } + } + else + { + std::string error; + + if (status == LUA_YIELD) + { + error = "thread yielded unexpectedly"; + } + else if (const char* str = lua_tostring(T, -1)) + { + error = str; + } + + error += "\nstack backtrace:\n"; + error += lua_debugtrace(T); + + error = "Error:" + error; + + fprintf(stdout, "%s", error.c_str()); + } + + lua_pop(L, 1); + return std::string(); +} + +extern "C" const char* executeScript(const char* source) +{ + // setup flags + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + // create new state + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup state + setupState(L); + + // sandbox thread + luaL_sandboxthread(L); + + // static string for caching result (prevents dangling ptr on function exit) + static std::string result; + + // run code + collect error + result = runCode(L, source); + + return result.empty() ? NULL : result.c_str(); +} diff --git a/CMakeLists.txt b/CMakeLists.txt index 9c69521ed..bafc59e59 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -9,6 +9,7 @@ project(Luau LANGUAGES CXX) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) +option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) add_library(Luau.Ast STATIC) @@ -18,26 +19,22 @@ add_library(Luau.VM STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) - if(NOT EMSCRIPTEN) - add_executable(Luau.Analyze.CLI) - else() - # add -fexceptions for emscripten to allow exceptions to be caught in C++ - set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -fexceptions") - endif() + add_executable(Luau.Analyze.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) - - if(NOT EMSCRIPTEN) - set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) - endif() + set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) endif() -if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) +if(LUAU_BUILD_TESTS) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) endif() +if(LUAU_BUILD_WEB) + add_executable(Luau.Web) +endif() + include(Sources.cmake) target_compile_features(Luau.Ast PUBLIC cxx_std_17) @@ -72,16 +69,18 @@ if(LUAU_WERROR) endif() endif() +if(LUAU_BUILD_WEB) + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + list(APPEND LUAU_OPTIONS -fexceptions) +endif() + target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) - - if(NOT EMSCRIPTEN) - target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) - endif() + target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Repl.CLI PRIVATE extern) target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) @@ -93,20 +92,10 @@ if(LUAU_BUILD_CLI) endif() endif() - if(NOT EMSCRIPTEN) - target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) - endif() - - if(EMSCRIPTEN) - # declare exported functions to emscripten - target_link_options(Luau.Repl.CLI PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap'] -fexceptions) - - # custom output directory for wasm + js file - set_target_properties(Luau.Repl.CLI PROPERTIES RUNTIME_OUTPUT_DIRECTORY ${CMAKE_SOURCE_DIR}/docs/assets/luau) - endif() + target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) endif() -if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) +if(LUAU_BUILD_TESTS) target_compile_options(Luau.UnitTest PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.UnitTest PRIVATE extern) target_link_libraries(Luau.UnitTest PRIVATE Luau.Analysis Luau.Compiler) @@ -115,3 +104,17 @@ if(LUAU_BUILD_TESTS AND NOT EMSCRIPTEN) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) endif() + +if(LUAU_BUILD_WEB) + target_compile_options(Luau.Web PRIVATE ${LUAU_OPTIONS}) + target_link_libraries(Luau.Web PRIVATE Luau.Compiler Luau.VM) + + # declare exported functions to emscripten + target_link_options(Luau.Web PRIVATE -sEXPORTED_FUNCTIONS=['_executeScript'] -sEXPORTED_RUNTIME_METHODS=['ccall','cwrap']) + + # add -fexceptions for emscripten to allow exceptions to be caught in C++ + target_link_options(Luau.Web PRIVATE -fexceptions) + + # the output is a single .js file with an embedded wasm blob + target_link_options(Luau.Web PRIVATE -sSINGLE_FILE=1) +endif() diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 5b93c1dc0..2c1e85ff0 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -321,13 +321,15 @@ struct Compiler compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); } - setDebugLine(expr->func); + setDebugLineEnd(expr->func); if (expr->self) { AstExprIndexName* fi = expr->func->as(); LUAU_ASSERT(fi); + setDebugLine(fi->indexLocation); + BytecodeBuilder::StringRef iname = sref(fi->index); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) @@ -1313,6 +1315,8 @@ struct Compiler RegScope rs(this); uint8_t reg = compileExprAuto(expr->expr, rs); + setDebugLine(expr->indexLocation); + BytecodeBuilder::StringRef iname = sref(expr->index); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) @@ -2710,6 +2714,12 @@ struct Compiler bytecode.setDebugLine(node->location.begin.line + 1); } + void setDebugLine(const Location& location) + { + if (options.debugLevel >= 1) + bytecode.setDebugLine(location.begin.line + 1); + } + void setDebugLineEnd(AstNode* node) { if (options.debugLevel >= 1) @@ -3650,7 +3660,7 @@ struct Compiler { if (options.vectorLib) { - if (builtin.object == options.vectorLib && builtin.method == options.vectorCtor) + if (builtin.isMethod(options.vectorLib, options.vectorCtor)) return LBF_VECTOR; } else diff --git a/Makefile b/Makefile index cab3d43f1..15c7ff7a4 100644 --- a/Makefile +++ b/Makefile @@ -35,7 +35,7 @@ ANALYZE_CLI_SOURCES=CLI/FileUtils.cpp CLI/Analyze.cpp ANALYZE_CLI_OBJECTS=$(ANALYZE_CLI_SOURCES:%=$(BUILD)/%.o) ANALYZE_CLI_TARGET=$(BUILD)/luau-analyze -FUZZ_SOURCES=$(wildcard fuzz/*.cpp) +FUZZ_SOURCES=$(wildcard fuzz/*.cpp) fuzz/luau.pb.cpp FUZZ_OBJECTS=$(FUZZ_SOURCES:%=$(BUILD)/%.o) TESTS_ARGS= @@ -167,8 +167,8 @@ fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator cd fuzz && ../build/libprotobuf-mutator/external.protobuf/bin/protoc luau.proto --cpp_out=. mv fuzz/luau.pb.cc fuzz/luau.pb.cpp -$(BUILD)/fuzz/proto.cpp.o: build/libprotobuf-mutator -$(BUILD)/fuzz/protoprint.cpp.o: build/libprotobuf-mutator +$(BUILD)/fuzz/proto.cpp.o: fuzz/luau.pb.cpp +$(BUILD)/fuzz/protoprint.cpp.o: fuzz/luau.pb.cpp build/libprotobuf-mutator: git clone https://github.com/google/libprotobuf-mutator build/libprotobuf-mutator diff --git a/Sources.cmake b/Sources.cmake index 23b931c6b..57df9b91e 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -54,6 +54,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/Scope.h Analysis/include/Luau/Substitution.h Analysis/include/Luau/Symbol.h + Analysis/include/Luau/ToDot.h Analysis/include/Luau/TopoSortStatements.h Analysis/include/Luau/ToString.h Analysis/include/Luau/Transpiler.h @@ -86,6 +87,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/Scope.cpp Analysis/src/Substitution.cpp Analysis/src/Symbol.cpp + Analysis/src/ToDot.cpp Analysis/src/TopoSortStatements.cpp Analysis/src/ToString.cpp Analysis/src/Transpiler.cpp @@ -118,6 +120,7 @@ target_sources(Luau.VM PRIVATE VM/src/ldo.cpp VM/src/lfunc.cpp VM/src/lgc.cpp + VM/src/lgcdebug.cpp VM/src/linit.cpp VM/src/lmathlib.cpp VM/src/lmem.cpp @@ -194,6 +197,7 @@ if(TARGET Luau.UnitTest) tests/RequireTracer.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp + tests/ToDot.test.cpp tests/TopoSort.test.cpp tests/ToString.test.cpp tests/Transpiler.test.cpp @@ -224,3 +228,9 @@ if(TARGET Luau.Conformance) tests/Conformance.test.cpp tests/main.cpp) endif() + +if(TARGET Luau.Web) + # Luau.Web Sources + target_sources(Luau.Web PRIVATE + CLI/Web.cpp) +endif() diff --git a/VM/include/lua.h b/VM/include/lua.h index 1568d191f..7078acd0f 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -21,6 +21,7 @@ #define LUA_ENVIRONINDEX (-10001) #define LUA_GLOBALSINDEX (-10002) #define lua_upvalueindex(i) (LUA_GLOBALSINDEX - (i)) +#define lua_ispseudo(i) ((i) <= LUA_REGISTRYINDEX) /* thread status; 0 is OK */ enum lua_Status @@ -108,6 +109,7 @@ LUA_API int lua_isthreadreset(lua_State* L); /* ** basic stack manipulation */ +LUA_API int lua_absindex(lua_State* L, int idx); LUA_API int lua_gettop(lua_State* L); LUA_API void lua_settop(lua_State* L, int idx); LUA_API void lua_pushvalue(lua_State* L, int idx); @@ -159,7 +161,11 @@ LUA_API void lua_pushnil(lua_State* L); LUA_API void lua_pushnumber(lua_State* L, double n); LUA_API void lua_pushinteger(lua_State* L, int n); LUA_API void lua_pushunsigned(lua_State* L, unsigned n); +#if LUA_VECTOR_SIZE == 4 +LUA_API void lua_pushvector(lua_State* L, float x, float y, float z, float w); +#else LUA_API void lua_pushvector(lua_State* L, float x, float y, float z); +#endif LUA_API void lua_pushlstring(lua_State* L, const char* s, size_t l); LUA_API void lua_pushstring(lua_State* L, const char* s); LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp); @@ -183,7 +189,7 @@ LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); -LUA_API void* lua_newuserdata(lua_State* L, size_t sz, int tag); +LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API int lua_getmetatable(lua_State* L, int objindex); LUA_API void lua_getfenv(lua_State* L, int idx); @@ -227,6 +233,7 @@ enum lua_GCOp LUA_GCRESTART, LUA_GCCOLLECT, LUA_GCCOUNT, + LUA_GCCOUNTB, LUA_GCISRUNNING, // garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation @@ -281,6 +288,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_pop(L, n) lua_settop(L, -(n)-1) #define lua_newtable(L) lua_createtable(L, 0, 0) +#define lua_newuserdata(L, s) lua_newuserdatatagged(L, s, 0) #define lua_strlen(L, i) lua_objlen(L, (i)) @@ -289,6 +297,7 @@ LUA_API void lua_unref(lua_State* L, int ref); #define lua_islightuserdata(L, n) (lua_type(L, (n)) == LUA_TLIGHTUSERDATA) #define lua_isnil(L, n) (lua_type(L, (n)) == LUA_TNIL) #define lua_isboolean(L, n) (lua_type(L, (n)) == LUA_TBOOLEAN) +#define lua_isvector(L, n) (lua_type(L, (n)) == LUA_TVECTOR) #define lua_isthread(L, n) (lua_type(L, (n)) == LUA_TTHREAD) #define lua_isnone(L, n) (lua_type(L, (n)) == LUA_TNONE) #define lua_isnoneornil(L, n) (lua_type(L, (n)) <= LUA_TNIL) diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index aa008a240..a01a14819 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -34,7 +34,10 @@ #endif /* Can be used to reconfigure visibility/exports for public APIs */ +#ifndef LUA_API #define LUA_API extern +#endif + #define LUALIB_API LUA_API /* Can be used to reconfigure visibility for internal APIs */ @@ -47,10 +50,14 @@ #endif /* Can be used to reconfigure internal error handling to use longjmp instead of C++ EH */ +#ifndef LUA_USE_LONGJMP #define LUA_USE_LONGJMP 0 +#endif /* LUA_IDSIZE gives the maximum size for the description of the source */ +#ifndef LUA_IDSIZE #define LUA_IDSIZE 256 +#endif /* @@ LUAI_GCGOAL defines the desired top heap size in relation to the live heap @@ -59,7 +66,9 @@ ** mean larger GC pauses which mean slower collection.) You can also change ** this value dynamically. */ +#ifndef LUAI_GCGOAL #define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ +#endif /* @@ LUAI_GCSTEPMUL / LUAI_GCSTEPSIZE define the default speed of garbage collection @@ -69,38 +78,63 @@ ** CHANGE it if you want to change the granularity of the garbage ** collection. */ +#ifndef LUAI_GCSTEPMUL #define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ +#endif + +#ifndef LUAI_GCSTEPSIZE #define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#endif /* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ +#ifndef LUA_MINSTACK #define LUA_MINSTACK 20 +#endif /* LUAI_MAXCSTACK limits the number of Lua stack slots that a C function can use */ +#ifndef LUAI_MAXCSTACK #define LUAI_MAXCSTACK 8000 +#endif /* LUAI_MAXCALLS limits the number of nested calls */ +#ifndef LUAI_MAXCALLS #define LUAI_MAXCALLS 20000 +#endif /* LUAI_MAXCCALLS is the maximum depth for nested C calls; this limit depends on native stack size */ +#ifndef LUAI_MAXCCALLS #define LUAI_MAXCCALLS 200 +#endif /* buffer size used for on-stack string operations; this limit depends on native stack size */ +#ifndef LUA_BUFFERSIZE #define LUA_BUFFERSIZE 512 +#endif /* number of valid Lua userdata tags */ +#ifndef LUA_UTAG_LIMIT #define LUA_UTAG_LIMIT 128 +#endif /* upper bound for number of size classes used by page allocator */ +#ifndef LUA_SIZECLASSES #define LUA_SIZECLASSES 32 +#endif /* available number of separate memory categories */ +#ifndef LUA_MEMORY_CATEGORIES #define LUA_MEMORY_CATEGORIES 256 +#endif /* minimum size for the string table (must be power of 2) */ +#ifndef LUA_MINSTRTABSIZE #define LUA_MINSTRTABSIZE 32 +#endif /* maximum number of captures supported by pattern matching */ +#ifndef LUA_MAXCAPTURES #define LUA_MAXCAPTURES 32 +#endif /* }================================================================== */ @@ -122,3 +156,7 @@ void* s; \ long l; \ } + +#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ + +#define LUA_EXTRA_SIZE LUA_VECTOR_SIZE - 2 diff --git a/VM/include/lualib.h b/VM/include/lualib.h index fa836955c..baf27b47e 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -25,11 +25,17 @@ LUALIB_API const char* luaL_optlstring(lua_State* L, int numArg, const char* def LUALIB_API double luaL_checknumber(lua_State* L, int numArg); LUALIB_API double luaL_optnumber(lua_State* L, int nArg, double def); +LUALIB_API int luaL_checkboolean(lua_State* L, int narg); +LUALIB_API int luaL_optboolean(lua_State* L, int narg, int def); + LUALIB_API int luaL_checkinteger(lua_State* L, int numArg); LUALIB_API int luaL_optinteger(lua_State* L, int nArg, int def); LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int numArg); LUALIB_API unsigned luaL_optunsigned(lua_State* L, int numArg, unsigned def); +LUALIB_API const float* luaL_checkvector(lua_State* L, int narg); +LUALIB_API const float* luaL_optvector(lua_State* L, int narg, const float* def); + LUALIB_API void luaL_checkstack(lua_State* L, int sz, const char* msg); LUALIB_API void luaL_checktype(lua_State* L, int narg, int t); LUALIB_API void luaL_checkany(lua_State* L, int narg); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index a79ba0d40..76043b9cd 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -13,6 +13,8 @@ #include +LUAU_FASTFLAG(LuauActivateBeforeExec) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -170,6 +172,12 @@ lua_State* lua_mainthread(lua_State* L) ** basic stack manipulation */ +int lua_absindex(lua_State* L, int idx) +{ + api_check(L, (idx > 0 && idx <= L->top - L->base) || (idx < 0 && -idx <= L->top - L->base) || lua_ispseudo(idx)); + return idx > 0 || lua_ispseudo(idx) ? idx : cast_int(L->top - L->base) + idx + 1; +} + int lua_gettop(lua_State* L) { return cast_int(L->top - L->base); @@ -550,12 +558,21 @@ void lua_pushunsigned(lua_State* L, unsigned u) return; } +#if LUA_VECTOR_SIZE == 4 +void lua_pushvector(lua_State* L, float x, float y, float z, float w) +{ + setvvalue(L->top, x, y, z, w); + api_incr_top(L); + return; +} +#else void lua_pushvector(lua_State* L, float x, float y, float z) { - setvvalue(L->top, x, y, z); + setvvalue(L->top, x, y, z, 0.0f); api_incr_top(L); return; } +#endif void lua_pushlstring(lua_State* L, const char* s, size_t len) { @@ -922,14 +939,21 @@ void lua_call(lua_State* L, int nargs, int nresults) checkresults(L, nargs, nresults); func = L->top - (nargs + 1); - int wasActive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + if (FFlag::LuauActivateBeforeExec) + { + luaD_call(L, func, nresults); + } + else + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - luaD_call(L, func, nresults); + luaD_call(L, func, nresults); - if (!wasActive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } adjustresults(L, nresults); return; @@ -970,14 +994,21 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - int wasActive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + if (FFlag::LuauActivateBeforeExec) + { + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + } + else + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - if (!wasActive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } adjustresults(L, nresults); return status; @@ -1030,6 +1061,11 @@ int lua_gc(lua_State* L, int what, int data) res = cast_int(g->totalbytes >> 10); break; } + case LUA_GCCOUNTB: + { + res = cast_int(g->totalbytes & 1023); + break; + } case LUA_GCISRUNNING: { res = (g->GCthreshold != SIZE_MAX); @@ -1146,7 +1182,7 @@ void lua_concat(lua_State* L, int n) return; } -void* lua_newuserdata(lua_State* L, size_t sz, int tag) +void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) { api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_checkGC(L); @@ -1231,6 +1267,7 @@ uintptr_t lua_encodepointer(lua_State* L, uintptr_t p) int lua_ref(lua_State* L, int idx) { + api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ int ref = LUA_REFNIL; global_State* g = L->global; StkId p = index2adr(L, idx); diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 2a684ee4e..7ed2a62ee 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -30,7 +30,7 @@ static const char* currfuncname(lua_State* L) return debugname; } -LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) +l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) { const char* fname = currfuncname(L); @@ -40,7 +40,7 @@ LUALIB_API l_noret luaL_argerrorL(lua_State* L, int narg, const char* extramsg) luaL_error(L, "invalid argument #%d (%s)", narg, extramsg); } -LUALIB_API l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname) +l_noret luaL_typeerrorL(lua_State* L, int narg, const char* tname) { const char* fname = currfuncname(L); const TValue* obj = luaA_toobject(L, narg); @@ -66,7 +66,7 @@ static l_noret tag_error(lua_State* L, int narg, int tag) luaL_typeerrorL(L, narg, lua_typename(L, tag)); } -LUALIB_API void luaL_where(lua_State* L, int level) +void luaL_where(lua_State* L, int level) { lua_Debug ar; if (lua_getinfo(L, level, "sl", &ar) && ar.currentline > 0) @@ -77,7 +77,7 @@ LUALIB_API void luaL_where(lua_State* L, int level) lua_pushliteral(L, ""); /* else, no information available... */ } -LUALIB_API l_noret luaL_errorL(lua_State* L, const char* fmt, ...) +l_noret luaL_errorL(lua_State* L, const char* fmt, ...) { va_list argp; va_start(argp, fmt); @@ -90,7 +90,7 @@ LUALIB_API l_noret luaL_errorL(lua_State* L, const char* fmt, ...) /* }====================================================== */ -LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) +int luaL_checkoption(lua_State* L, int narg, const char* def, const char* const lst[]) { const char* name = (def) ? luaL_optstring(L, narg, def) : luaL_checkstring(L, narg); int i; @@ -101,7 +101,7 @@ LUALIB_API int luaL_checkoption(lua_State* L, int narg, const char* def, const c luaL_argerrorL(L, narg, msg); } -LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname) +int luaL_newmetatable(lua_State* L, const char* tname) { lua_getfield(L, LUA_REGISTRYINDEX, tname); /* get registry.name */ if (!lua_isnil(L, -1)) /* name already in use? */ @@ -113,7 +113,7 @@ LUALIB_API int luaL_newmetatable(lua_State* L, const char* tname) return 1; } -LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname) +void* luaL_checkudata(lua_State* L, int ud, const char* tname) { void* p = lua_touserdata(L, ud); if (p != NULL) @@ -131,25 +131,25 @@ LUALIB_API void* luaL_checkudata(lua_State* L, int ud, const char* tname) luaL_typeerrorL(L, ud, tname); /* else error */ } -LUALIB_API void luaL_checkstack(lua_State* L, int space, const char* mes) +void luaL_checkstack(lua_State* L, int space, const char* mes) { if (!lua_checkstack(L, space)) luaL_error(L, "stack overflow (%s)", mes); } -LUALIB_API void luaL_checktype(lua_State* L, int narg, int t) +void luaL_checktype(lua_State* L, int narg, int t) { if (lua_type(L, narg) != t) tag_error(L, narg, t); } -LUALIB_API void luaL_checkany(lua_State* L, int narg) +void luaL_checkany(lua_State* L, int narg) { if (lua_type(L, narg) == LUA_TNONE) luaL_error(L, "missing argument #%d", narg); } -LUALIB_API const char* luaL_checklstring(lua_State* L, int narg, size_t* len) +const char* luaL_checklstring(lua_State* L, int narg, size_t* len) { const char* s = lua_tolstring(L, narg, len); if (!s) @@ -157,7 +157,7 @@ LUALIB_API const char* luaL_checklstring(lua_State* L, int narg, size_t* len) return s; } -LUALIB_API const char* luaL_optlstring(lua_State* L, int narg, const char* def, size_t* len) +const char* luaL_optlstring(lua_State* L, int narg, const char* def, size_t* len) { if (lua_isnoneornil(L, narg)) { @@ -169,7 +169,7 @@ LUALIB_API const char* luaL_optlstring(lua_State* L, int narg, const char* def, return luaL_checklstring(L, narg, len); } -LUALIB_API double luaL_checknumber(lua_State* L, int narg) +double luaL_checknumber(lua_State* L, int narg) { int isnum; double d = lua_tonumberx(L, narg, &isnum); @@ -178,12 +178,28 @@ LUALIB_API double luaL_checknumber(lua_State* L, int narg) return d; } -LUALIB_API double luaL_optnumber(lua_State* L, int narg, double def) +double luaL_optnumber(lua_State* L, int narg, double def) { return luaL_opt(L, luaL_checknumber, narg, def); } -LUALIB_API int luaL_checkinteger(lua_State* L, int narg) +int luaL_checkboolean(lua_State* L, int narg) +{ + // This checks specifically for boolean values, ignoring + // all other truthy/falsy values. If the desired result + // is true if value is present then lua_toboolean should + // directly be used instead. + if (!lua_isboolean(L, narg)) + tag_error(L, narg, LUA_TBOOLEAN); + return lua_toboolean(L, narg); +} + +int luaL_optboolean(lua_State* L, int narg, int def) +{ + return luaL_opt(L, luaL_checkboolean, narg, def); +} + +int luaL_checkinteger(lua_State* L, int narg) { int isnum; int d = lua_tointegerx(L, narg, &isnum); @@ -192,12 +208,12 @@ LUALIB_API int luaL_checkinteger(lua_State* L, int narg) return d; } -LUALIB_API int luaL_optinteger(lua_State* L, int narg, int def) +int luaL_optinteger(lua_State* L, int narg, int def) { return luaL_opt(L, luaL_checkinteger, narg, def); } -LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int narg) +unsigned luaL_checkunsigned(lua_State* L, int narg) { int isnum; unsigned d = lua_tounsignedx(L, narg, &isnum); @@ -206,12 +222,25 @@ LUALIB_API unsigned luaL_checkunsigned(lua_State* L, int narg) return d; } -LUALIB_API unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) +unsigned luaL_optunsigned(lua_State* L, int narg, unsigned def) { return luaL_opt(L, luaL_checkunsigned, narg, def); } -LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* event) +const float* luaL_checkvector(lua_State* L, int narg) +{ + const float* v = lua_tovector(L, narg); + if (!v) + tag_error(L, narg, LUA_TVECTOR); + return v; +} + +const float* luaL_optvector(lua_State* L, int narg, const float* def) +{ + return luaL_opt(L, luaL_checkvector, narg, def); +} + +int luaL_getmetafield(lua_State* L, int obj, const char* event) { if (!lua_getmetatable(L, obj)) /* no metatable? */ return 0; @@ -229,7 +258,7 @@ LUALIB_API int luaL_getmetafield(lua_State* L, int obj, const char* event) } } -LUALIB_API int luaL_callmeta(lua_State* L, int obj, const char* event) +int luaL_callmeta(lua_State* L, int obj, const char* event) { obj = abs_index(L, obj); if (!luaL_getmetafield(L, obj, event)) /* no metafield? */ @@ -247,7 +276,7 @@ static int libsize(const luaL_Reg* l) return size; } -LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) +void luaL_register(lua_State* L, const char* libname, const luaL_Reg* l) { if (libname) { @@ -273,7 +302,7 @@ LUALIB_API void luaL_register(lua_State* L, const char* libname, const luaL_Reg* } } -LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) +const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) { const char* e; lua_pushvalue(L, idx); @@ -324,7 +353,7 @@ static size_t getnextbuffersize(lua_State* L, size_t currentsize, size_t desired return newsize; } -LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B) +void luaL_buffinit(lua_State* L, luaL_Buffer* B) { // start with an internal buffer B->p = B->buffer; @@ -334,14 +363,14 @@ LUALIB_API void luaL_buffinit(lua_State* L, luaL_Buffer* B) B->storage = nullptr; } -LUALIB_API char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size) +char* luaL_buffinitsize(lua_State* L, luaL_Buffer* B, size_t size) { luaL_buffinit(L, B); luaL_reservebuffer(B, size, -1); return B->p; } -LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) +char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int boxloc) { lua_State* L = B->L; @@ -372,13 +401,13 @@ LUALIB_API char* luaL_extendbuffer(luaL_Buffer* B, size_t additionalsize, int bo return B->p; } -LUALIB_API void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) +void luaL_reservebuffer(luaL_Buffer* B, size_t size, int boxloc) { if (size_t(B->end - B->p) < size) luaL_extendbuffer(B, size - (B->end - B->p), boxloc); } -LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) +void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) { if (size_t(B->end - B->p) < len) luaL_extendbuffer(B, len - (B->end - B->p), -1); @@ -387,7 +416,7 @@ LUALIB_API void luaL_addlstring(luaL_Buffer* B, const char* s, size_t len) B->p += len; } -LUALIB_API void luaL_addvalue(luaL_Buffer* B) +void luaL_addvalue(luaL_Buffer* B) { lua_State* L = B->L; @@ -404,7 +433,7 @@ LUALIB_API void luaL_addvalue(luaL_Buffer* B) } } -LUALIB_API void luaL_pushresult(luaL_Buffer* B) +void luaL_pushresult(luaL_Buffer* B) { lua_State* L = B->L; @@ -428,7 +457,7 @@ LUALIB_API void luaL_pushresult(luaL_Buffer* B) } } -LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size) +void luaL_pushresultsize(luaL_Buffer* B, size_t size) { B->p += size; luaL_pushresult(B); @@ -436,7 +465,7 @@ LUALIB_API void luaL_pushresultsize(luaL_Buffer* B, size_t size) /* }====================================================== */ -LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len) +const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { if (luaL_callmeta(L, idx, "__tostring")) /* is there a metafield? */ { @@ -462,7 +491,11 @@ LUALIB_API const char* luaL_tolstring(lua_State* L, int idx, size_t* len) case LUA_TVECTOR: { const float* v = lua_tovector(L, idx); +#if LUA_VECTOR_SIZE == 4 + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); +#else lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); +#endif break; } default: diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 61798e2bc..881c804db 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -401,7 +401,7 @@ static int luaB_newproxy(lua_State* L) bool needsmt = lua_toboolean(L, 1); - lua_newuserdata(L, 0, 0); + lua_newuserdata(L, 0); if (needsmt) { @@ -441,7 +441,7 @@ static void auxopen(lua_State* L, const char* name, lua_CFunction f, lua_CFuncti lua_setfield(L, -2, name); } -LUALIB_API int luaopen_base(lua_State* L) +int luaopen_base(lua_State* L) { /* set global _G */ lua_pushvalue(L, LUA_GLOBALSINDEX); diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 907c43c42..8b511edf0 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -236,7 +236,7 @@ static const luaL_Reg bitlib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_bit32(lua_State* L) +int luaopen_bit32(lua_State* L) { luaL_register(L, LUA_BITLIBNAME, bitlib); diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 9ab57ac9b..34e9ebc1f 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1018,13 +1018,23 @@ static int luauF_tunpack(lua_State* L, StkId res, TValue* arg0, int nresults, St static int luauF_vector(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) { +#if LUA_VECTOR_SIZE == 4 + if (nparams >= 4 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1) && ttisnumber(args + 2)) +#else if (nparams >= 3 && nresults <= 1 && ttisnumber(arg0) && ttisnumber(args) && ttisnumber(args + 1)) +#endif { double x = nvalue(arg0); double y = nvalue(args); double z = nvalue(args + 1); - setvvalue(res, float(x), float(y), float(z)); +#if LUA_VECTOR_SIZE == 4 + double w = nvalue(args + 2); + setvvalue(res, float(x), float(y), float(z), float(w)); +#else + setvvalue(res, float(x), float(y), float(z), 0.0f); +#endif + return 1; } diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 0178fae84..abcde7796 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -272,7 +272,7 @@ static const luaL_Reg co_funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_coroutine(lua_State* L) +int luaopen_coroutine(lua_State* L) { luaL_register(L, LUA_COLIBNAME, co_funcs); diff --git a/VM/src/ldblib.cpp b/VM/src/ldblib.cpp index 965d2b3de..93d8703a6 100644 --- a/VM/src/ldblib.cpp +++ b/VM/src/ldblib.cpp @@ -160,7 +160,7 @@ static const luaL_Reg dblib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_debug(lua_State* L) +int luaopen_debug(lua_State* L) { luaL_register(L, LUA_DBLIBNAME, dblib); return 1; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 1259d4619..62bbdb7c9 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,9 +17,9 @@ #include -LUAU_FASTFLAGVARIABLE(LuauExceptionMessageFix, false) LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) +LUAU_FASTFLAGVARIABLE(LuauActivateBeforeExec, false) /* ** {====================================================== @@ -74,35 +74,28 @@ class lua_exception : public std::exception const char* what() const throw() override { - if (FFlag::LuauExceptionMessageFix) + // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. + if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) { - // LUA_ERRRUN/LUA_ERRSYNTAX pass an object on the stack which is intended to describe the error. - if (status == LUA_ERRRUN || status == LUA_ERRSYNTAX) + // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. + if (const char* str = lua_tostring(L, -1)) { - // Conversion to a string could still fail. For example if a user passes a non-string/non-number argument to `error()`. - if (const char* str = lua_tostring(L, -1)) - { - return str; - } - } - - switch (status) - { - case LUA_ERRRUN: - return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; - case LUA_ERRSYNTAX: - return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; - case LUA_ERRMEM: - return "lua_exception: " LUA_MEMERRMSG; - case LUA_ERRERR: - return "lua_exception: " LUA_ERRERRMSG; - default: - return "lua_exception: unexpected exception status"; + return str; } } - else + + switch (status) { - return lua_tostring(L, -1); + case LUA_ERRRUN: + return "lua_exception: LUA_ERRRUN (no string/number provided as description)"; + case LUA_ERRSYNTAX: + return "lua_exception: LUA_ERRSYNTAX (no string/number provided as description)"; + case LUA_ERRMEM: + return "lua_exception: " LUA_MEMERRMSG; + case LUA_ERRERR: + return "lua_exception: " LUA_ERRERRMSG; + default: + return "lua_exception: unexpected exception status"; } } @@ -234,7 +227,22 @@ void luaD_call(lua_State* L, StkId func, int nResults) if (luau_precall(L, func, nResults) == PCRLUA) { /* is a Lua function? */ L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ - luau_execute(L); /* call it */ + + if (FFlag::LuauActivateBeforeExec) + { + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); + + luau_execute(L); /* call it */ + + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } + else + { + luau_execute(L); /* call it */ + } } L->nCcalls--; luaC_checkGC(L); @@ -527,10 +535,10 @@ static void restore_stack_limit(lua_State* L) int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t ef) { - int status; unsigned short oldnCcalls = L->nCcalls; ptrdiff_t old_ci = saveci(L, L->ci); - status = luaD_rawrunprotected(L, func, u); + int oldactive = luaC_threadactive(L); + int status = luaD_rawrunprotected(L, func, u); if (status != 0) { // call user-defined error function (used in xpcall) @@ -541,6 +549,13 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } + if (FFlag::LuauActivateBeforeExec) + { + // since the call failed with an error, we might have to reset the 'active' thread state + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); + } + if (FFlag::LuauCcallRestoreFix) { // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 11f79d1a3..ab416041e 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -10,9 +10,7 @@ #include "ldo.h" #include -#include -LUAU_FASTFLAGVARIABLE(LuauRescanGrayAgainForwardBarrier, false) LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) LUAU_FASTFLAG(LuauArrayBoundary) @@ -988,7 +986,7 @@ void luaC_barriertable(lua_State* L, Table* t, GCObject* v) GCObject* o = obj2gco(t); // in the second propagation stage, table assignment barrier works as a forward barrier - if (FFlag::LuauRescanGrayAgainForwardBarrier && g->gcstate == GCSpropagateagain) + if (g->gcstate == GCSpropagateagain) { LUAU_ASSERT(isblack(o) && iswhite(v) && !isdead(g, v) && !isdead(g, o)); reallymarkobject(g, v); @@ -1044,550 +1042,6 @@ void luaC_linkupval(lua_State* L, UpVal* uv) } } -static void validateobjref(global_State* g, GCObject* f, GCObject* t) -{ - LUAU_ASSERT(!isdead(g, t)); - - if (keepinvariant(g)) - { - /* basic incremental invariant: black can't point to white */ - LUAU_ASSERT(!(isblack(f) && iswhite(t))); - } -} - -static void validateref(global_State* g, GCObject* f, TValue* v) -{ - if (iscollectable(v)) - { - LUAU_ASSERT(ttype(v) == gcvalue(v)->gch.tt); - validateobjref(g, f, gcvalue(v)); - } -} - -static void validatetable(global_State* g, Table* h) -{ - int sizenode = 1 << h->lsizenode; - - if (FFlag::LuauArrayBoundary) - LUAU_ASSERT(h->lastfree <= sizenode); - else - LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); - - if (h->metatable) - validateobjref(g, obj2gco(h), obj2gco(h->metatable)); - - for (int i = 0; i < h->sizearray; ++i) - validateref(g, obj2gco(h), &h->array[i]); - - for (int i = 0; i < sizenode; ++i) - { - LuaNode* n = &h->node[i]; - - LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); - LUAU_ASSERT(i + gnext(n) >= 0 && i + gnext(n) < sizenode); - - if (!ttisnil(gval(n))) - { - TValue k = {}; - k.tt = gkey(n)->tt; - k.value = gkey(n)->value; - - validateref(g, obj2gco(h), &k); - validateref(g, obj2gco(h), gval(n)); - } - } -} - -static void validateclosure(global_State* g, Closure* cl) -{ - validateobjref(g, obj2gco(cl), obj2gco(cl->env)); - - if (cl->isC) - { - for (int i = 0; i < cl->nupvalues; ++i) - validateref(g, obj2gco(cl), &cl->c.upvals[i]); - } - else - { - LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); - - validateobjref(g, obj2gco(cl), obj2gco(cl->l.p)); - - for (int i = 0; i < cl->nupvalues; ++i) - validateref(g, obj2gco(cl), &cl->l.uprefs[i]); - } -} - -static void validatestack(global_State* g, lua_State* l) -{ - validateref(g, obj2gco(l), gt(l)); - - for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) - { - LUAU_ASSERT(l->stack <= ci->base); - LUAU_ASSERT(ci->func <= ci->base && ci->base <= ci->top); - LUAU_ASSERT(ci->top <= l->stack_last); - } - - // note: stack refs can violate gc invariant so we only check for liveness - for (StkId o = l->stack; o < l->top; ++o) - checkliveness(g, o); - - if (l->namecall) - validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - - for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) - { - LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); - LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); - } -} - -static void validateproto(global_State* g, Proto* f) -{ - if (f->source) - validateobjref(g, obj2gco(f), obj2gco(f->source)); - - if (f->debugname) - validateobjref(g, obj2gco(f), obj2gco(f->debugname)); - - for (int i = 0; i < f->sizek; ++i) - validateref(g, obj2gco(f), &f->k[i]); - - for (int i = 0; i < f->sizeupvalues; ++i) - if (f->upvalues[i]) - validateobjref(g, obj2gco(f), obj2gco(f->upvalues[i])); - - for (int i = 0; i < f->sizep; ++i) - if (f->p[i]) - validateobjref(g, obj2gco(f), obj2gco(f->p[i])); - - for (int i = 0; i < f->sizelocvars; i++) - if (f->locvars[i].varname) - validateobjref(g, obj2gco(f), obj2gco(f->locvars[i].varname)); -} - -static void validateobj(global_State* g, GCObject* o) -{ - /* dead objects can only occur during sweep */ - if (isdead(g, o)) - { - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); - return; - } - - switch (o->gch.tt) - { - case LUA_TSTRING: - break; - - case LUA_TTABLE: - validatetable(g, gco2h(o)); - break; - - case LUA_TFUNCTION: - validateclosure(g, gco2cl(o)); - break; - - case LUA_TUSERDATA: - if (gco2u(o)->metatable) - validateobjref(g, o, obj2gco(gco2u(o)->metatable)); - break; - - case LUA_TTHREAD: - validatestack(g, gco2th(o)); - break; - - case LUA_TPROTO: - validateproto(g, gco2p(o)); - break; - - case LUA_TUPVAL: - validateref(g, o, gco2uv(o)->v); - break; - - default: - LUAU_ASSERT(!"unexpected object type"); - } -} - -static void validatelist(global_State* g, GCObject* o) -{ - while (o) - { - validateobj(g, o); - - o = o->gch.next; - } -} - -static void validategraylist(global_State* g, GCObject* o) -{ - if (!keepinvariant(g)) - return; - - while (o) - { - LUAU_ASSERT(isgray(o)); - - switch (o->gch.tt) - { - case LUA_TTABLE: - o = gco2h(o)->gclist; - break; - case LUA_TFUNCTION: - o = gco2cl(o)->gclist; - break; - case LUA_TTHREAD: - o = gco2th(o)->gclist; - break; - case LUA_TPROTO: - o = gco2p(o)->gclist; - break; - default: - LUAU_ASSERT(!"unknown object in gray list"); - return; - } - } -} - -void luaC_validate(lua_State* L) -{ - global_State* g = L->global; - - LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); - checkliveness(g, &g->registry); - - for (int i = 0; i < LUA_T_COUNT; ++i) - if (g->mt[i]) - LUAU_ASSERT(!isdead(g, obj2gco(g->mt[i]))); - - validategraylist(g, g->weak); - validategraylist(g, g->gray); - validategraylist(g, g->grayagain); - - for (int i = 0; i < g->strt.size; ++i) - validatelist(g, g->strt.hash[i]); - - validatelist(g, g->rootgc); - validatelist(g, g->strbufgc); - - for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) - { - LUAU_ASSERT(uv->tt == LUA_TUPVAL); - LUAU_ASSERT(uv->v != &uv->u.value); - LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); - } -} - -inline bool safejson(char ch) -{ - return unsigned(ch) < 128 && ch >= 32 && ch != '\\' && ch != '\"'; -} - -static void dumpref(FILE* f, GCObject* o) -{ - fprintf(f, "\"%p\"", o); -} - -static void dumprefs(FILE* f, TValue* data, size_t size) -{ - bool first = true; - - for (size_t i = 0; i < size; ++i) - { - if (iscollectable(&data[i])) - { - if (!first) - fputc(',', f); - first = false; - - dumpref(f, gcvalue(&data[i])); - } - } -} - -static void dumpstringdata(FILE* f, const char* data, size_t len) -{ - for (size_t i = 0; i < len; ++i) - fputc(safejson(data[i]) ? data[i] : '?', f); -} - -static void dumpstring(FILE* f, TString* ts) -{ - fprintf(f, "{\"type\":\"string\",\"cat\":%d,\"size\":%d,\"data\":\"", ts->memcat, int(sizestring(ts->len))); - dumpstringdata(f, ts->data, ts->len); - fprintf(f, "\"}"); -} - -static void dumptable(FILE* f, Table* h) -{ - size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); - - fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); - - if (h->node != &luaH_dummynode) - { - fprintf(f, ",\"pairs\":["); - - bool first = true; - - for (int i = 0; i < sizenode(h); ++i) - { - const LuaNode& n = h->node[i]; - - if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) - { - if (!first) - fputc(',', f); - first = false; - - if (iscollectable(&n.key)) - dumpref(f, gcvalue(&n.key)); - else - fprintf(f, "null"); - - fputc(',', f); - - if (iscollectable(&n.val)) - dumpref(f, gcvalue(&n.val)); - else - fprintf(f, "null"); - } - } - - fprintf(f, "]"); - } - if (h->sizearray) - { - fprintf(f, ",\"array\":["); - dumprefs(f, h->array, h->sizearray); - fprintf(f, "]"); - } - if (h->metatable) - { - fprintf(f, ",\"metatable\":"); - dumpref(f, obj2gco(h->metatable)); - } - fprintf(f, "}"); -} - -static void dumpclosure(FILE* f, Closure* cl) -{ - fprintf(f, "{\"type\":\"function\",\"cat\":%d,\"size\":%d", cl->memcat, - cl->isC ? int(sizeCclosure(cl->nupvalues)) : int(sizeLclosure(cl->nupvalues))); - - fprintf(f, ",\"env\":"); - dumpref(f, obj2gco(cl->env)); - if (cl->isC) - { - if (cl->nupvalues) - { - fprintf(f, ",\"upvalues\":["); - dumprefs(f, cl->c.upvals, cl->nupvalues); - fprintf(f, "]"); - } - } - else - { - fprintf(f, ",\"proto\":"); - dumpref(f, obj2gco(cl->l.p)); - if (cl->nupvalues) - { - fprintf(f, ",\"upvalues\":["); - dumprefs(f, cl->l.uprefs, cl->nupvalues); - fprintf(f, "]"); - } - } - fprintf(f, "}"); -} - -static void dumpudata(FILE* f, Udata* u) -{ - fprintf(f, "{\"type\":\"userdata\",\"cat\":%d,\"size\":%d,\"tag\":%d", u->memcat, int(sizeudata(u->len)), u->tag); - - if (u->metatable) - { - fprintf(f, ",\"metatable\":"); - dumpref(f, obj2gco(u->metatable)); - } - fprintf(f, "}"); -} - -static void dumpthread(FILE* f, lua_State* th) -{ - size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; - - fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); - - if (iscollectable(&th->l_gt)) - { - fprintf(f, ",\"env\":"); - dumpref(f, gcvalue(&th->l_gt)); - } - - Closure* tcl = 0; - for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) - { - if (ttisfunction(ci->func)) - { - tcl = clvalue(ci->func); - break; - } - } - - if (tcl && !tcl->isC && tcl->l.p->source) - { - Proto* p = tcl->l.p; - - fprintf(f, ",\"source\":\""); - dumpstringdata(f, p->source->data, p->source->len); - fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); - } - - if (th->top > th->stack) - { - fprintf(f, ",\"stack\":["); - dumprefs(f, th->stack, th->top - th->stack); - fprintf(f, "]"); - } - fprintf(f, "}"); -} - -static void dumpproto(FILE* f, Proto* p) -{ - size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + - sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; - - fprintf(f, "{\"type\":\"proto\",\"cat\":%d,\"size\":%d", p->memcat, int(size)); - - if (p->source) - { - fprintf(f, ",\"source\":\""); - dumpstringdata(f, p->source->data, p->source->len); - fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); - } - - if (p->sizek) - { - fprintf(f, ",\"constants\":["); - dumprefs(f, p->k, p->sizek); - fprintf(f, "]"); - } - - if (p->sizep) - { - fprintf(f, ",\"protos\":["); - for (int i = 0; i < p->sizep; ++i) - { - if (i != 0) - fputc(',', f); - dumpref(f, obj2gco(p->p[i])); - } - fprintf(f, "]"); - } - - fprintf(f, "}"); -} - -static void dumpupval(FILE* f, UpVal* uv) -{ - fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); - - if (iscollectable(uv->v)) - { - fprintf(f, ",\"object\":"); - dumpref(f, gcvalue(uv->v)); - } - fprintf(f, "}"); -} - -static void dumpobj(FILE* f, GCObject* o) -{ - switch (o->gch.tt) - { - case LUA_TSTRING: - return dumpstring(f, gco2ts(o)); - - case LUA_TTABLE: - return dumptable(f, gco2h(o)); - - case LUA_TFUNCTION: - return dumpclosure(f, gco2cl(o)); - - case LUA_TUSERDATA: - return dumpudata(f, gco2u(o)); - - case LUA_TTHREAD: - return dumpthread(f, gco2th(o)); - - case LUA_TPROTO: - return dumpproto(f, gco2p(o)); - - case LUA_TUPVAL: - return dumpupval(f, gco2uv(o)); - - default: - LUAU_ASSERT(0); - } -} - -static void dumplist(FILE* f, GCObject* o) -{ - while (o) - { - dumpref(f, o); - fputc(':', f); - dumpobj(f, o); - fputc(',', f); - fputc('\n', f); - - // thread has additional list containing collectable objects that are not present in rootgc - if (o->gch.tt == LUA_TTHREAD) - dumplist(f, gco2th(o)->openupval); - - o = o->gch.next; - } -} - -void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) -{ - global_State* g = L->global; - FILE* f = static_cast(file); - - fprintf(f, "{\"objects\":{\n"); - dumplist(f, g->rootgc); - dumplist(f, g->strbufgc); - for (int i = 0; i < g->strt.size; ++i) - dumplist(f, g->strt.hash[i]); - - fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , - fprintf(f, "},\"roots\":{\n"); - fprintf(f, "\"mainthread\":"); - dumpref(f, obj2gco(g->mainthread)); - fprintf(f, ",\"registry\":"); - dumpref(f, gcvalue(&g->registry)); - - fprintf(f, "},\"stats\":{\n"); - - fprintf(f, "\"size\":%d,\n", int(g->totalbytes)); - - fprintf(f, "\"categories\":{\n"); - for (int i = 0; i < LUA_MEMORY_CATEGORIES; i++) - { - if (size_t bytes = g->memcatbytes[i]) - { - if (categoryName) - fprintf(f, "\"%d\":{\"name\":\"%s\", \"size\":%d},\n", i, categoryName(L, i), int(bytes)); - else - fprintf(f, "\"%d\":{\"size\":%d},\n", i, int(bytes)); - } - } - fprintf(f, "\"none\":{}\n"); // to avoid issues with trailing , - fprintf(f, "}\n"); - fprintf(f, "}}\n"); -} - // measure the allocation rate in bytes/sec // returns -1 if allocation rate cannot be measured int64_t luaC_allocationrate(lua_State* L) diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp new file mode 100644 index 000000000..a79e7b953 --- /dev/null +++ b/VM/src/lgcdebug.cpp @@ -0,0 +1,558 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "lgc.h" + +#include "lobject.h" +#include "lstate.h" +#include "ltable.h" +#include "lfunc.h" +#include "lstring.h" + +#include +#include + +LUAU_FASTFLAG(LuauArrayBoundary) + +static void validateobjref(global_State* g, GCObject* f, GCObject* t) +{ + LUAU_ASSERT(!isdead(g, t)); + + if (keepinvariant(g)) + { + /* basic incremental invariant: black can't point to white */ + LUAU_ASSERT(!(isblack(f) && iswhite(t))); + } +} + +static void validateref(global_State* g, GCObject* f, TValue* v) +{ + if (iscollectable(v)) + { + LUAU_ASSERT(ttype(v) == gcvalue(v)->gch.tt); + validateobjref(g, f, gcvalue(v)); + } +} + +static void validatetable(global_State* g, Table* h) +{ + int sizenode = 1 << h->lsizenode; + + if (FFlag::LuauArrayBoundary) + LUAU_ASSERT(h->lastfree <= sizenode); + else + LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); + + if (h->metatable) + validateobjref(g, obj2gco(h), obj2gco(h->metatable)); + + for (int i = 0; i < h->sizearray; ++i) + validateref(g, obj2gco(h), &h->array[i]); + + for (int i = 0; i < sizenode; ++i) + { + LuaNode* n = &h->node[i]; + + LUAU_ASSERT(ttype(gkey(n)) != LUA_TDEADKEY || ttisnil(gval(n))); + LUAU_ASSERT(i + gnext(n) >= 0 && i + gnext(n) < sizenode); + + if (!ttisnil(gval(n))) + { + TValue k = {}; + k.tt = gkey(n)->tt; + k.value = gkey(n)->value; + + validateref(g, obj2gco(h), &k); + validateref(g, obj2gco(h), gval(n)); + } + } +} + +static void validateclosure(global_State* g, Closure* cl) +{ + validateobjref(g, obj2gco(cl), obj2gco(cl->env)); + + if (cl->isC) + { + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->c.upvals[i]); + } + else + { + LUAU_ASSERT(cl->nupvalues == cl->l.p->nups); + + validateobjref(g, obj2gco(cl), obj2gco(cl->l.p)); + + for (int i = 0; i < cl->nupvalues; ++i) + validateref(g, obj2gco(cl), &cl->l.uprefs[i]); + } +} + +static void validatestack(global_State* g, lua_State* l) +{ + validateref(g, obj2gco(l), gt(l)); + + for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) + { + LUAU_ASSERT(l->stack <= ci->base); + LUAU_ASSERT(ci->func <= ci->base && ci->base <= ci->top); + LUAU_ASSERT(ci->top <= l->stack_last); + } + + // note: stack refs can violate gc invariant so we only check for liveness + for (StkId o = l->stack; o < l->top; ++o) + checkliveness(g, o); + + if (l->namecall) + validateobjref(g, obj2gco(l), obj2gco(l->namecall)); + + for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) + { + LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); + LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); + } +} + +static void validateproto(global_State* g, Proto* f) +{ + if (f->source) + validateobjref(g, obj2gco(f), obj2gco(f->source)); + + if (f->debugname) + validateobjref(g, obj2gco(f), obj2gco(f->debugname)); + + for (int i = 0; i < f->sizek; ++i) + validateref(g, obj2gco(f), &f->k[i]); + + for (int i = 0; i < f->sizeupvalues; ++i) + if (f->upvalues[i]) + validateobjref(g, obj2gco(f), obj2gco(f->upvalues[i])); + + for (int i = 0; i < f->sizep; ++i) + if (f->p[i]) + validateobjref(g, obj2gco(f), obj2gco(f->p[i])); + + for (int i = 0; i < f->sizelocvars; i++) + if (f->locvars[i].varname) + validateobjref(g, obj2gco(f), obj2gco(f->locvars[i].varname)); +} + +static void validateobj(global_State* g, GCObject* o) +{ + /* dead objects can only occur during sweep */ + if (isdead(g, o)) + { + LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + return; + } + + switch (o->gch.tt) + { + case LUA_TSTRING: + break; + + case LUA_TTABLE: + validatetable(g, gco2h(o)); + break; + + case LUA_TFUNCTION: + validateclosure(g, gco2cl(o)); + break; + + case LUA_TUSERDATA: + if (gco2u(o)->metatable) + validateobjref(g, o, obj2gco(gco2u(o)->metatable)); + break; + + case LUA_TTHREAD: + validatestack(g, gco2th(o)); + break; + + case LUA_TPROTO: + validateproto(g, gco2p(o)); + break; + + case LUA_TUPVAL: + validateref(g, o, gco2uv(o)->v); + break; + + default: + LUAU_ASSERT(!"unexpected object type"); + } +} + +static void validatelist(global_State* g, GCObject* o) +{ + while (o) + { + validateobj(g, o); + + o = o->gch.next; + } +} + +static void validategraylist(global_State* g, GCObject* o) +{ + if (!keepinvariant(g)) + return; + + while (o) + { + LUAU_ASSERT(isgray(o)); + + switch (o->gch.tt) + { + case LUA_TTABLE: + o = gco2h(o)->gclist; + break; + case LUA_TFUNCTION: + o = gco2cl(o)->gclist; + break; + case LUA_TTHREAD: + o = gco2th(o)->gclist; + break; + case LUA_TPROTO: + o = gco2p(o)->gclist; + break; + default: + LUAU_ASSERT(!"unknown object in gray list"); + return; + } + } +} + +void luaC_validate(lua_State* L) +{ + global_State* g = L->global; + + LUAU_ASSERT(!isdead(g, obj2gco(g->mainthread))); + checkliveness(g, &g->registry); + + for (int i = 0; i < LUA_T_COUNT; ++i) + if (g->mt[i]) + LUAU_ASSERT(!isdead(g, obj2gco(g->mt[i]))); + + validategraylist(g, g->weak); + validategraylist(g, g->gray); + validategraylist(g, g->grayagain); + + for (int i = 0; i < g->strt.size; ++i) + validatelist(g, g->strt.hash[i]); + + validatelist(g, g->rootgc); + validatelist(g, g->strbufgc); + + for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) + { + LUAU_ASSERT(uv->tt == LUA_TUPVAL); + LUAU_ASSERT(uv->v != &uv->u.value); + LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); + } +} + +inline bool safejson(char ch) +{ + return unsigned(ch) < 128 && ch >= 32 && ch != '\\' && ch != '\"'; +} + +static void dumpref(FILE* f, GCObject* o) +{ + fprintf(f, "\"%p\"", o); +} + +static void dumprefs(FILE* f, TValue* data, size_t size) +{ + bool first = true; + + for (size_t i = 0; i < size; ++i) + { + if (iscollectable(&data[i])) + { + if (!first) + fputc(',', f); + first = false; + + dumpref(f, gcvalue(&data[i])); + } + } +} + +static void dumpstringdata(FILE* f, const char* data, size_t len) +{ + for (size_t i = 0; i < len; ++i) + fputc(safejson(data[i]) ? data[i] : '?', f); +} + +static void dumpstring(FILE* f, TString* ts) +{ + fprintf(f, "{\"type\":\"string\",\"cat\":%d,\"size\":%d,\"data\":\"", ts->memcat, int(sizestring(ts->len))); + dumpstringdata(f, ts->data, ts->len); + fprintf(f, "\"}"); +} + +static void dumptable(FILE* f, Table* h) +{ + size_t size = sizeof(Table) + (h->node == &luaH_dummynode ? 0 : sizenode(h) * sizeof(LuaNode)) + h->sizearray * sizeof(TValue); + + fprintf(f, "{\"type\":\"table\",\"cat\":%d,\"size\":%d", h->memcat, int(size)); + + if (h->node != &luaH_dummynode) + { + fprintf(f, ",\"pairs\":["); + + bool first = true; + + for (int i = 0; i < sizenode(h); ++i) + { + const LuaNode& n = h->node[i]; + + if (!ttisnil(&n.val) && (iscollectable(&n.key) || iscollectable(&n.val))) + { + if (!first) + fputc(',', f); + first = false; + + if (iscollectable(&n.key)) + dumpref(f, gcvalue(&n.key)); + else + fprintf(f, "null"); + + fputc(',', f); + + if (iscollectable(&n.val)) + dumpref(f, gcvalue(&n.val)); + else + fprintf(f, "null"); + } + } + + fprintf(f, "]"); + } + if (h->sizearray) + { + fprintf(f, ",\"array\":["); + dumprefs(f, h->array, h->sizearray); + fprintf(f, "]"); + } + if (h->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(h->metatable)); + } + fprintf(f, "}"); +} + +static void dumpclosure(FILE* f, Closure* cl) +{ + fprintf(f, "{\"type\":\"function\",\"cat\":%d,\"size\":%d", cl->memcat, + cl->isC ? int(sizeCclosure(cl->nupvalues)) : int(sizeLclosure(cl->nupvalues))); + + fprintf(f, ",\"env\":"); + dumpref(f, obj2gco(cl->env)); + if (cl->isC) + { + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->c.upvals, cl->nupvalues); + fprintf(f, "]"); + } + } + else + { + fprintf(f, ",\"proto\":"); + dumpref(f, obj2gco(cl->l.p)); + if (cl->nupvalues) + { + fprintf(f, ",\"upvalues\":["); + dumprefs(f, cl->l.uprefs, cl->nupvalues); + fprintf(f, "]"); + } + } + fprintf(f, "}"); +} + +static void dumpudata(FILE* f, Udata* u) +{ + fprintf(f, "{\"type\":\"userdata\",\"cat\":%d,\"size\":%d,\"tag\":%d", u->memcat, int(sizeudata(u->len)), u->tag); + + if (u->metatable) + { + fprintf(f, ",\"metatable\":"); + dumpref(f, obj2gco(u->metatable)); + } + fprintf(f, "}"); +} + +static void dumpthread(FILE* f, lua_State* th) +{ + size_t size = sizeof(lua_State) + sizeof(TValue) * th->stacksize + sizeof(CallInfo) * th->size_ci; + + fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); + + if (iscollectable(&th->l_gt)) + { + fprintf(f, ",\"env\":"); + dumpref(f, gcvalue(&th->l_gt)); + } + + Closure* tcl = 0; + for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) + { + if (ttisfunction(ci->func)) + { + tcl = clvalue(ci->func); + break; + } + } + + if (tcl && !tcl->isC && tcl->l.p->source) + { + Proto* p = tcl->l.p; + + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (th->top > th->stack) + { + fprintf(f, ",\"stack\":["); + dumprefs(f, th->stack, th->top - th->stack); + fprintf(f, "]"); + } + fprintf(f, "}"); +} + +static void dumpproto(FILE* f, Proto* p) +{ + size_t size = sizeof(Proto) + sizeof(Instruction) * p->sizecode + sizeof(Proto*) * p->sizep + sizeof(TValue) * p->sizek + p->sizelineinfo + + sizeof(LocVar) * p->sizelocvars + sizeof(TString*) * p->sizeupvalues; + + fprintf(f, "{\"type\":\"proto\",\"cat\":%d,\"size\":%d", p->memcat, int(size)); + + if (p->source) + { + fprintf(f, ",\"source\":\""); + dumpstringdata(f, p->source->data, p->source->len); + fprintf(f, "\",\"line\":%d", p->abslineinfo ? p->abslineinfo[0] : 0); + } + + if (p->sizek) + { + fprintf(f, ",\"constants\":["); + dumprefs(f, p->k, p->sizek); + fprintf(f, "]"); + } + + if (p->sizep) + { + fprintf(f, ",\"protos\":["); + for (int i = 0; i < p->sizep; ++i) + { + if (i != 0) + fputc(',', f); + dumpref(f, obj2gco(p->p[i])); + } + fprintf(f, "]"); + } + + fprintf(f, "}"); +} + +static void dumpupval(FILE* f, UpVal* uv) +{ + fprintf(f, "{\"type\":\"upvalue\",\"cat\":%d,\"size\":%d", uv->memcat, int(sizeof(UpVal))); + + if (iscollectable(uv->v)) + { + fprintf(f, ",\"object\":"); + dumpref(f, gcvalue(uv->v)); + } + fprintf(f, "}"); +} + +static void dumpobj(FILE* f, GCObject* o) +{ + switch (o->gch.tt) + { + case LUA_TSTRING: + return dumpstring(f, gco2ts(o)); + + case LUA_TTABLE: + return dumptable(f, gco2h(o)); + + case LUA_TFUNCTION: + return dumpclosure(f, gco2cl(o)); + + case LUA_TUSERDATA: + return dumpudata(f, gco2u(o)); + + case LUA_TTHREAD: + return dumpthread(f, gco2th(o)); + + case LUA_TPROTO: + return dumpproto(f, gco2p(o)); + + case LUA_TUPVAL: + return dumpupval(f, gco2uv(o)); + + default: + LUAU_ASSERT(0); + } +} + +static void dumplist(FILE* f, GCObject* o) +{ + while (o) + { + dumpref(f, o); + fputc(':', f); + dumpobj(f, o); + fputc(',', f); + fputc('\n', f); + + // thread has additional list containing collectable objects that are not present in rootgc + if (o->gch.tt == LUA_TTHREAD) + dumplist(f, gco2th(o)->openupval); + + o = o->gch.next; + } +} + +void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) +{ + global_State* g = L->global; + FILE* f = static_cast(file); + + fprintf(f, "{\"objects\":{\n"); + dumplist(f, g->rootgc); + dumplist(f, g->strbufgc); + for (int i = 0; i < g->strt.size; ++i) + dumplist(f, g->strt.hash[i]); + + fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , + fprintf(f, "},\"roots\":{\n"); + fprintf(f, "\"mainthread\":"); + dumpref(f, obj2gco(g->mainthread)); + fprintf(f, ",\"registry\":"); + dumpref(f, gcvalue(&g->registry)); + + fprintf(f, "},\"stats\":{\n"); + + fprintf(f, "\"size\":%d,\n", int(g->totalbytes)); + + fprintf(f, "\"categories\":{\n"); + for (int i = 0; i < LUA_MEMORY_CATEGORIES; i++) + { + if (size_t bytes = g->memcatbytes[i]) + { + if (categoryName) + fprintf(f, "\"%d\":{\"name\":\"%s\", \"size\":%d},\n", i, categoryName(L, i), int(bytes)); + else + fprintf(f, "\"%d\":{\"size\":%d},\n", i, int(bytes)); + } + } + fprintf(f, "\"none\":{}\n"); // to avoid issues with trailing , + fprintf(f, "}\n"); + fprintf(f, "}}\n"); +} diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index 4e40165ab..c93f431f1 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -17,7 +17,7 @@ static const luaL_Reg lualibs[] = { {NULL, NULL}, }; -LUALIB_API void luaL_openlibs(lua_State* L) +void luaL_openlibs(lua_State* L) { const luaL_Reg* lib = lualibs; for (; lib->func; lib++) @@ -28,7 +28,7 @@ LUALIB_API void luaL_openlibs(lua_State* L) } } -LUALIB_API void luaL_sandbox(lua_State* L) +void luaL_sandbox(lua_State* L) { // set all libraries to read-only lua_pushnil(L); @@ -44,14 +44,14 @@ LUALIB_API void luaL_sandbox(lua_State* L) lua_pushliteral(L, ""); lua_getmetatable(L, -1); lua_setreadonly(L, -1, true); - lua_pop(L, 1); + lua_pop(L, 2); // set globals to readonly and activate safeenv since the env is immutable lua_setreadonly(L, LUA_GLOBALSINDEX, true); lua_setsafeenv(L, LUA_GLOBALSINDEX, true); } -LUALIB_API void luaL_sandboxthread(lua_State* L) +void luaL_sandboxthread(lua_State* L) { // create new global table that proxies reads to original table lua_newtable(L); @@ -81,7 +81,7 @@ static void* l_alloc(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsi return realloc(ptr, nsize); } -LUALIB_API lua_State* luaL_newstate(void) +lua_State* luaL_newstate(void) { return lua_newstate(l_alloc, NULL); } diff --git a/VM/src/lmathlib.cpp b/VM/src/lmathlib.cpp index 8e476a526..a6e7b4940 100644 --- a/VM/src/lmathlib.cpp +++ b/VM/src/lmathlib.cpp @@ -385,8 +385,7 @@ static int math_sign(lua_State* L) static int math_round(lua_State* L) { - double v = luaL_checknumber(L, 1); - lua_pushnumber(L, round(v)); + lua_pushnumber(L, round(luaL_checknumber(L, 1))); return 1; } @@ -429,7 +428,7 @@ static const luaL_Reg mathlib[] = { /* ** Open math library */ -LUALIB_API int luaopen_math(lua_State* L) +int luaopen_math(lua_State* L) { uint64_t seed = uintptr_t(L); seed ^= time(NULL); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index d8b265cba..9f9d4a98f 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -33,11 +33,17 @@ #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : ms32) #endif +#if LUA_VECTOR_SIZE == 4 +static_assert(sizeof(TValue) == ABISWITCH(24, 24, 24), "size mismatch for value"); +static_assert(sizeof(LuaNode) == ABISWITCH(48, 48, 48), "size mismatch for table entry"); +#else static_assert(sizeof(TValue) == ABISWITCH(16, 16, 16), "size mismatch for value"); +static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table entry"); +#endif + static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); -static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table entry"); const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 43f8014b4..67f832dc5 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -18,12 +18,20 @@ inline bool luai_veceq(const float* a, const float* b) { +#if LUA_VECTOR_SIZE == 4 + return a[0] == b[0] && a[1] == b[1] && a[2] == b[2] && a[3] == b[3]; +#else return a[0] == b[0] && a[1] == b[1] && a[2] == b[2]; +#endif } inline bool luai_vecisnan(const float* a) { +#if LUA_VECTOR_SIZE == 4 + return a[0] != a[0] || a[1] != a[1] || a[2] != a[2] || a[3] != a[3]; +#else return a[0] != a[0] || a[1] != a[1] || a[2] != a[2]; +#endif } LUAU_FASTMATH_BEGIN diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index bf13e6e97..370c7b283 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -15,7 +15,7 @@ -const TValue luaO_nilobject_ = {{NULL}, LUA_TNIL}; +const TValue luaO_nilobject_ = {{NULL}, {0}, LUA_TNIL}; int luaO_log2(unsigned int x) { diff --git a/VM/src/lobject.h b/VM/src/lobject.h index c5f2e2f4f..ba040af6c 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -47,7 +47,7 @@ typedef union typedef struct lua_TValue { Value value; - int extra; + int extra[LUA_EXTRA_SIZE]; int tt; } TValue; @@ -105,15 +105,28 @@ typedef struct lua_TValue i_o->tt = LUA_TNUMBER; \ } -#define setvvalue(obj, x, y, z) \ +#if LUA_VECTOR_SIZE == 4 +#define setvvalue(obj, x, y, z, w) \ { \ TValue* i_o = (obj); \ float* i_v = i_o->value.v; \ i_v[0] = (x); \ i_v[1] = (y); \ i_v[2] = (z); \ + i_v[3] = (w); \ i_o->tt = LUA_TVECTOR; \ } +#else +#define setvvalue(obj, x, y, z, w) \ + { \ + TValue* i_o = (obj); \ + float* i_v = i_o->value.v; \ + i_v[0] = (x); \ + i_v[1] = (y); \ + i_v[2] = (z); \ + i_o->tt = LUA_TVECTOR; \ + } +#endif #define setpvalue(obj, x) \ { \ @@ -364,7 +377,7 @@ typedef struct Closure typedef struct TKey { ::Value value; - int extra; + int extra[LUA_EXTRA_SIZE]; unsigned tt : 4; int next : 28; /* for chaining */ } TKey; @@ -381,7 +394,7 @@ typedef struct LuaNode LuaNode* n_ = (node); \ const TValue* i_o = (obj); \ n_->key.value = i_o->value; \ - n_->key.extra = i_o->extra; \ + memcpy(n_->key.extra, i_o->extra, sizeof(n_->key.extra)); \ n_->key.tt = i_o->tt; \ checkliveness(L->global, i_o); \ } @@ -392,7 +405,7 @@ typedef struct LuaNode TValue* i_o = (obj); \ const LuaNode* n_ = (node); \ i_o->value = n_->key.value; \ - i_o->extra = n_->key.extra; \ + memcpy(i_o->extra, n_->key.extra, sizeof(i_o->extra)); \ i_o->tt = n_->key.tt; \ checkliveness(L->global, i_o); \ } diff --git a/VM/src/loslib.cpp b/VM/src/loslib.cpp index 8eaef60cb..b5901865c 100644 --- a/VM/src/loslib.cpp +++ b/VM/src/loslib.cpp @@ -186,7 +186,7 @@ static const luaL_Reg syslib[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_os(lua_State* L) +int luaopen_os(lua_State* L) { luaL_register(L, LUA_OSLIBNAME, syslib); return 1; diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index b576f8093..0b3054ae4 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -1657,7 +1657,7 @@ static void createmetatable(lua_State* L) /* ** Open string library */ -LUALIB_API int luaopen_string(lua_State* L) +int luaopen_string(lua_State* L) { luaL_register(L, LUA_STRLIBNAME, strlib); createmetatable(L); diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 07d22d596..0b55fceac 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -31,18 +31,19 @@ LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) #define MAXSIZE (1 << MAXBITS) static_assert(offsetof(LuaNode, val) == 0, "Unexpected Node memory layout, pointer cast in gval2slot is incorrect"); + // TKey is bitpacked for memory efficiency so we need to validate bit counts for worst case -static_assert(TKey{{NULL}, 0, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); -static_assert(TKey{{NULL}, 0, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); -static_assert(TKey{{NULL}, 0, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); +static_assert(TKey{{NULL}, {0}, LUA_TDEADKEY, 0}.tt == LUA_TDEADKEY, "not enough bits for tt"); +static_assert(TKey{{NULL}, {0}, LUA_TNIL, MAXSIZE - 1}.next == MAXSIZE - 1, "not enough bits for next"); +static_assert(TKey{{NULL}, {0}, LUA_TNIL, -(MAXSIZE - 1)}.next == -(MAXSIZE - 1), "not enough bits for next"); // reset cache of absent metamethods, cache is updated in luaT_gettm #define invalidateTMcache(t) t->flags = 0 // empty hash data points to dummynode so that we can always dereference it const LuaNode luaH_dummynode = { - {{NULL}, 0, LUA_TNIL}, /* value */ - {{NULL}, 0, LUA_TNIL, 0} /* key */ + {{NULL}, {0}, LUA_TNIL}, /* value */ + {{NULL}, {0}, LUA_TNIL, 0} /* key */ }; #define dummynode (&luaH_dummynode) @@ -96,7 +97,7 @@ static LuaNode* hashnum(const Table* t, double n) static LuaNode* hashvec(const Table* t, const float* v) { - unsigned int i[3]; + unsigned int i[LUA_VECTOR_SIZE]; memcpy(i, v, sizeof(i)); // convert -0 to 0 to make sure they hash to the same value @@ -112,6 +113,12 @@ static LuaNode* hashvec(const Table* t, const float* v) // Optimized Spatial Hashing for Collision Detection of Deformable Objects unsigned int h = (i[0] * 73856093) ^ (i[1] * 19349663) ^ (i[2] * 83492791); +#if LUA_VECTOR_SIZE == 4 + i[3] = (i[3] == 0x8000000) ? 0 : i[3]; + i[3] ^= i[3] >> 17; + h ^= i[3] * 39916801; +#endif + return hashpow2(t, h); } diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 370258189..0d3374efa 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -527,7 +527,7 @@ static const luaL_Reg tab_funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_table(lua_State* L) +int luaopen_table(lua_State* L) { luaL_register(L, LUA_TABLIBNAME, tab_funcs); diff --git a/VM/src/lutf8lib.cpp b/VM/src/lutf8lib.cpp index 378de3d0d..8bc8200a5 100644 --- a/VM/src/lutf8lib.cpp +++ b/VM/src/lutf8lib.cpp @@ -283,7 +283,7 @@ static const luaL_Reg funcs[] = { {NULL, NULL}, }; -LUALIB_API int luaopen_utf8(lua_State* L) +int luaopen_utf8(lua_State* L) { luaL_register(L, LUA_UTF8LIBNAME, funcs); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index eed2862b1..bf8d493eb 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -601,7 +601,13 @@ static void luau_execute(lua_State* L) const char* name = getstr(tsvalue(kv)); int ic = (name[0] | ' ') - 'x'; - if (unsigned(ic) < 3 && name[1] == '\0') +#if LUA_VECTOR_SIZE == 4 + // 'w' is before 'x' in ascii, so ic is -1 when indexing with 'w' + if (ic == -1) + ic = 3; +#endif + + if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') { setnvalue(ra, rb->value.v[ic]); VM_NEXT(); @@ -1526,7 +1532,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); VM_NEXT(); } else @@ -1572,7 +1578,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); VM_NEXT(); } else @@ -1618,21 +1624,21 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(rc)); - setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); VM_NEXT(); } else if (ttisvector(rb) && ttisvector(rc)) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); VM_NEXT(); } else if (ttisnumber(rb) && ttisvector(rc)) { float vb = cast_to(float, nvalue(rb)); const float* vc = rc->value.v; - setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2]); + setvvalue(ra, vb * vc[0], vb * vc[1], vb * vc[2], vb * vc[3]); VM_NEXT(); } else @@ -1679,21 +1685,21 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(rc)); - setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); VM_NEXT(); } else if (ttisvector(rb) && ttisvector(rc)) { const float* vb = rb->value.v; const float* vc = rc->value.v; - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); VM_NEXT(); } else if (ttisnumber(rb) && ttisvector(rc)) { float vb = cast_to(float, nvalue(rb)); const float* vc = rc->value.v; - setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2]); + setvvalue(ra, vb / vc[0], vb / vc[1], vb / vc[2], vb / vc[3]); VM_NEXT(); } else @@ -1826,7 +1832,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(kv)); - setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc); + setvvalue(ra, vb[0] * vc, vb[1] * vc, vb[2] * vc, vb[3] * vc); VM_NEXT(); } else @@ -1872,7 +1878,7 @@ static void luau_execute(lua_State* L) { const float* vb = rb->value.v; float vc = cast_to(float, nvalue(kv)); - setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc); + setvvalue(ra, vb[0] / vc, vb[1] / vc, vb[2] / vc, vb[3] / vc); VM_NEXT(); } else @@ -2037,7 +2043,7 @@ static void luau_execute(lua_State* L) else if (ttisvector(rb)) { const float* vb = rb->value.v; - setvvalue(ra, -vb[0], -vb[1], -vb[2]); + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); VM_NEXT(); } else diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index a168b6522..add3588d4 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -9,6 +9,7 @@ #include "lgc.h" #include "lmem.h" #include "lbytecode.h" +#include "lapi.h" #include @@ -162,9 +163,8 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size size_t GCthreshold = L->global->GCthreshold; L->global->GCthreshold = SIZE_MAX; - // env is 0 for current environment and a stack relative index otherwise - LUAU_ASSERT(env <= 0 && L->top - L->base >= -env); - Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(L->top + env); + // env is 0 for current environment and a stack index otherwise + Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(luaA_toobject(L, env)); TString* source = luaS_new(L, chunkname); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index f52e8e745..740a4cfd2 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -401,19 +401,19 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_ADD: - setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2]); + setvvalue(ra, vb[0] + vc[0], vb[1] + vc[1], vb[2] + vc[2], vb[3] + vc[3]); return; case TM_SUB: - setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2]); + setvvalue(ra, vb[0] - vc[0], vb[1] - vc[1], vb[2] - vc[2], vb[3] - vc[3]); return; case TM_MUL: - setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2]); + setvvalue(ra, vb[0] * vc[0], vb[1] * vc[1], vb[2] * vc[2], vb[3] * vc[3]); return; case TM_DIV: - setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2]); + setvvalue(ra, vb[0] / vc[0], vb[1] / vc[1], vb[2] / vc[2], vb[3] / vc[3]); return; case TM_UNM: - setvvalue(ra, -vb[0], -vb[1], -vb[2]); + setvvalue(ra, -vb[0], -vb[1], -vb[2], -vb[3]); return; default: break; @@ -430,10 +430,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_MUL: - setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc); + setvvalue(ra, vb[0] * nc, vb[1] * nc, vb[2] * nc, vb[3] * nc); return; case TM_DIV: - setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc); + setvvalue(ra, vb[0] / nc, vb[1] / nc, vb[2] / nc, vb[3] / nc); return; default: break; @@ -451,10 +451,10 @@ void luaV_doarith(lua_State* L, StkId ra, const TValue* rb, const TValue* rc, TM switch (op) { case TM_MUL: - setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2]); + setvvalue(ra, nb * vc[0], nb * vc[1], nb * vc[2], nb * vc[3]); return; case TM_DIV: - setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2]); + setvvalue(ra, nb / vc[0], nb / vc[1], nb / vc[2], nb / vc[3]); return; default: break; diff --git a/bench/gc/test_LB_mandel.lua b/bench/gc/test_LB_mandel.lua index 4be785022..fe5b4eb2f 100644 --- a/bench/gc/test_LB_mandel.lua +++ b/bench/gc/test_LB_mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/mandel.lua b/bench/tests/shootout/mandel.lua index 4be785022..fe5b4eb2f 100644 --- a/bench/tests/shootout/mandel.lua +++ b/bench/tests/shootout/mandel.lua @@ -88,7 +88,7 @@ for i=1,N do local y=ymin+(j-1)*dy S = S + level(x,y) end - -- if i % 10 == 0 then print(collectgarbage"count") end + -- if i % 10 == 0 then print(collectgarbage("count")) end end print(S) diff --git a/bench/tests/shootout/qt.lua b/bench/tests/shootout/qt.lua index de962a74f..79cbe38ba 100644 --- a/bench/tests/shootout/qt.lua +++ b/bench/tests/shootout/qt.lua @@ -275,7 +275,7 @@ local function memory(s) local t=os.clock() --local dt=string.format("%f",t-t0) local dt=t-t0 - --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage"count"/1024),"M\n") + --io.stdout:write(s,"\t",dt," sec\t",t," sec\t",math.floor(collectgarbage("count")/1024),"M\n") t0=t end @@ -286,7 +286,7 @@ local function do_(f,s) end local function julia(l,a,b) -memory"begin" +memory("begin") cx=a cy=b root=newcell() exterior=newcell() exterior.color=white @@ -297,14 +297,14 @@ memory"begin" do_(update,"update") repeat N=0 color(root,Rxmin,Rxmax,Rymin,Rymax) --print("color",N) - until N==0 memory"color" + until N==0 memory("color") repeat N=0 prewhite(root,Rxmin,Rxmax,Rymin,Rymax) --print("prewhite",N) - until N==0 memory"prewhite" + until N==0 memory("prewhite") do_(recolor,"recolor") do_(colorup,"colorup") --print("colorup",N) local g,b=do_(area,"area") --print("area",g,b,g+b) - show(i) memory"output" + show(i) memory("output") --print("edges",nE) end end diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index ae2399e49..27e534927 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -23,9 +23,11 @@ const bool kFuzzCompiler = true; const bool kFuzzLinter = true; const bool kFuzzTypeck = true; const bool kFuzzVM = true; -const bool kFuzzTypes = true; const bool kFuzzTranspile = true; +// Should we generate type annotations? +const bool kFuzzTypes = true; + static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); std::string protoprint(const luau::StatBlock& stat, bool types); diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index aa53a92b2..2090b0148 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -78,3 +78,26 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("AstQuery"); + +TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") +{ + ScopedFastFlag luauTailArgumentTypeInfo{"LuauTailArgumentTypeInfo", true}; + + check(R"( +local function foo() return 2 end +local function bar(a: number) return -a end +bar(foo()) + )"); + + auto oty = findTypeAtPosition(Position(3, 7)); + REQUIRE(oty); + CHECK_EQ("number", toString(*oty)); + + auto expectedOty = findExpectedTypeAtPosition(Position(3, 7)); + REQUIRE(expectedOty); + CHECK_EQ("number", toString(*expectedOty)); +} + +TEST_SUITE_END(); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 5a7c86023..3b74a99e4 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1935,6 +1935,39 @@ return target(b@1 CHECK(ac.entryMap["bar2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); + + check(R"( +local function bar(a: number) return -a end +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + +TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); + + check(R"( +local function foo() return 1 end +local function bar(a: number) return -a end +local abc = bar(@1) + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("foo")); + CHECK(ac.entryMap["foo"].parens == ParenthesesRecommendation::CursorAfter); +} + TEST_CASE_FIXTURE(ACFixture, "type_correct_sealed_table") { check(R"( @@ -2210,8 +2243,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") TEST_CASE_FIXTURE(ACFixture, "autocompleteSource_require") { - ScopedFastFlag luauResolveModuleNameWithoutACurrentModule("LuauResolveModuleNameWithoutACurrentModule", true); - std::string_view source = R"( local a = require(w -- Line 1 -- | Column 27 @@ -2287,8 +2318,6 @@ until TEST_CASE_FIXTURE(ACFixture, "if_then_else_elseif_completions") { - ScopedFastFlag sff{"ElseElseIfCompletionImprovements", true}; - check(R"( local elsewhere = false @@ -2585,9 +2614,6 @@ a = if temp then even elseif true then temp else e@9 TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - check(R"( type A = () -> T... local a: A<(number, s@1> diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 4ce8d08ae..6ba39adab 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -1057,6 +1057,18 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0("return if false then 10 else 20"), R"( LOADN R0 20 RETURN R0 1 +)"); + + // codegen for a true constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if true then {} else error()"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 +)"); + + // codegen for a false constant condition with non-constant expressions + CHECK_EQ("\n" + compileFunction0("return if false then error() else {}"), R"( +NEWTABLE R0 0 0 +RETURN R0 1 )"); // codegen for a false (in this case 'nil') constant condition @@ -2360,6 +2372,58 @@ Foo:Bar( )"); } +TEST_CASE("DebugLineInfoCallChain") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo = ... + +Foo +:Bar(1) +:Baz(2) +.Qux(3) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 1 +5: LOADN R4 1 +5: NAMECALL R2 R0 K0 +5: CALL R2 2 1 +6: LOADN R4 2 +6: NAMECALL R2 R2 K1 +6: CALL R2 2 1 +7: GETTABLEKS R1 R2 K2 +7: LOADN R2 3 +7: CALL R1 1 0 +8: RETURN R0 0 +)"); +} + +TEST_CASE("DebugLineInfoFastCall") +{ + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( +local Foo, Bar = ... + +return + math.max( + Foo, + Bar) +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: GETVARARGS R0 2 +5: FASTCALL2 18 R0 R1 +5 +5: MOVE R3 R0 +5: MOVE R4 R1 +5: GETIMPORT R2 2 +5: CALL R2 2 -1 +5: RETURN R2 -1 +)"); +} + TEST_CASE("DebugSource") { const char* source = R"( @@ -3742,4 +3806,108 @@ RETURN R0 0 )"); } +TEST_CASE("ConstantsNoFolding") +{ + const char* source = "return nil, true, 42, 'hello'"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.optimizationLevel = 0; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADNIL R0 +LOADB R1 1 +LOADK R2 K0 +LOADK R3 K1 +RETURN R0 4 +)"); +} + +TEST_CASE("VectorFastCall") +{ + const char* source = "return Vector3.new(1, 2, 3)"; + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); + Luau::CompileOptions options; + options.vectorLib = "Vector3"; + options.vectorCtor = "new"; + Luau::compileOrThrow(bcb, source, options); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +LOADN R1 1 +LOADN R2 2 +LOADN R3 3 +FASTCALL 54 +2 +GETIMPORT R0 2 +CALL R0 3 -1 +RETURN R0 -1 +)"); +} + +TEST_CASE("TypeAssertion") +{ + // validate that type assertions work with the compiler and that the code inside type assertion isn't evaluated + CHECK_EQ("\n" + compileFunction0(R"( +print(foo() :: typeof(error("compile time"))) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 1 +CALL R0 1 0 +RETURN R0 0 +)"); + + // note that above, foo() is treated as single-arg function; removing type assertion changes the bytecode + CHECK_EQ("\n" + compileFunction0(R"( +print(foo()) +)"), + R"( +GETIMPORT R0 1 +GETIMPORT R1 3 +CALL R1 0 -1 +CALL R0 -1 0 +RETURN R0 0 +)"); +} + +TEST_CASE("Arithmetics") +{ + // basic arithmetics codegen with non-constants + CHECK_EQ("\n" + compileFunction0(R"( +local a, b = ... +return a + b, a - b, a / b, a * b, a % b, a ^ b +)"), + R"( +GETVARARGS R0 2 +ADD R2 R0 R1 +SUB R3 R0 R1 +DIV R4 R0 R1 +MUL R5 R0 R1 +MOD R6 R0 R1 +POW R7 R0 R1 +RETURN R2 6 +)"); + + // basic arithmetics codegen with constants on the right side + // note that we don't simplify these expressions as we don't know the type of a + CHECK_EQ("\n" + compileFunction0(R"( +local a = ... +return a + 1, a - 1, a / 1, a * 1, a % 1, a ^ 1 +)"), + R"( +GETVARARGS R0 1 +ADDK R1 R0 K0 +SUBK R2 R0 K0 +DIVK R3 R0 K0 +MULK R4 R0 K0 +MODK R5 R0 K0 +POWK R6 R0 K0 +RETURN R1 6 +)"); +} + TEST_SUITE_END(); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index e495a2136..b2aad3163 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -67,44 +67,42 @@ static int lua_vector(lua_State* L) double y = luaL_checknumber(L, 2); double z = luaL_checknumber(L, 3); +#if LUA_VECTOR_SIZE == 4 + double w = luaL_optnumber(L, 4, 0.0); + lua_pushvector(L, float(x), float(y), float(z), float(w)); +#else lua_pushvector(L, float(x), float(y), float(z)); +#endif return 1; } static int lua_vector_dot(lua_State* L) { - const float* a = lua_tovector(L, 1); - const float* b = lua_tovector(L, 2); - - if (a && b) - { - lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); - return 1; - } + const float* a = luaL_checkvector(L, 1); + const float* b = luaL_checkvector(L, 2); - throw std::runtime_error("invalid arguments to vector:Dot"); + lua_pushnumber(L, a[0] * b[0] + a[1] * b[1] + a[2] * b[2]); + return 1; } static int lua_vector_index(lua_State* L) { + const float* v = luaL_checkvector(L, 1); const char* name = luaL_checkstring(L, 2); - if (const float* v = lua_tovector(L, 1)) + if (strcmp(name, "Magnitude") == 0) { - if (strcmp(name, "Magnitude") == 0) - { - lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); - return 1; - } + lua_pushnumber(L, sqrtf(v[0] * v[0] + v[1] * v[1] + v[2] * v[2])); + return 1; + } - if (strcmp(name, "Dot") == 0) - { - lua_pushcfunction(L, lua_vector_dot, "Dot"); - return 1; - } + if (strcmp(name, "Dot") == 0) + { + lua_pushcfunction(L, lua_vector_dot, "Dot"); + return 1; } - throw std::runtime_error(Luau::format("%s is not a valid member of vector", name)); + luaL_error(L, "%s is not a valid member of vector", name); } static int lua_vector_namecall(lua_State* L) @@ -115,7 +113,7 @@ static int lua_vector_namecall(lua_State* L) return lua_vector_dot(L); } - throw std::runtime_error(Luau::format("%s is not a valid method of vector", luaL_checkstring(L, 1))); + luaL_error(L, "%s is not a valid method of vector", luaL_checkstring(L, 1)); } int lua_silence(lua_State* L) @@ -373,11 +371,17 @@ TEST_CASE("Pack") TEST_CASE("Vector") { + ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + runConformance("vector.lua", [](lua_State* L) { lua_pushcfunction(L, lua_vector, "vector"); lua_setglobal(L, "vector"); +#if LUA_VECTOR_SIZE == 4 + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); +#else lua_pushvector(L, 0.0f, 0.0f, 0.0f); +#endif luaL_newmetatable(L, "vector"); lua_pushstring(L, "__index"); @@ -504,6 +508,9 @@ TEST_CASE("Debugger") cb->debugbreak = [](lua_State* L, lua_Debug* ar) { breakhits++; + // make sure we can trace the stack for every breakpoint we hit + lua_debugtrace(L); + // for every breakpoint, we break on the first invocation and continue on second // this allows us to easily step off breakpoints // (real implementaiton may require singlestepping) @@ -524,7 +531,7 @@ TEST_CASE("Debugger") L, [](lua_State* L) -> int { int line = luaL_checkinteger(L, 1); - bool enabled = lua_isboolean(L, 2) ? lua_toboolean(L, 2) : true; + bool enabled = luaL_optboolean(L, 2, true); lua_Debug ar = {}; lua_getinfo(L, 1, "f", &ar); @@ -699,21 +706,52 @@ TEST_CASE("ApiFunctionCalls") StateRef globalState = runConformance("apicalls.lua"); lua_State* L = globalState.get(); - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_call(L, 2, 1); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_call + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_call(L, 2, 1); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } - lua_getfield(L, LUA_GLOBALSINDEX, "add"); - lua_pushnumber(L, 40); - lua_pushnumber(L, 2); - lua_pcall(L, 2, 1, 0); - CHECK(lua_isnumber(L, -1)); - CHECK(lua_tonumber(L, -1) == 42); - lua_pop(L, 1); + // lua_pcall + { + lua_getfield(L, LUA_GLOBALSINDEX, "add"); + lua_pushnumber(L, 40); + lua_pushnumber(L, 2); + lua_pcall(L, 2, 1, 0); + CHECK(lua_isnumber(L, -1)); + CHECK(lua_tonumber(L, -1) == 42); + lua_pop(L, 1); + } + + // lua_equal with a sleeping thread wake up + { + ScopedFastFlag luauActivateBeforeExec("LuauActivateBeforeExec", true); + + lua_State* L2 = lua_newthread(L); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); + lua_pushnumber(L2, 42); + lua_pcall(L2, 1, 1, 0); + + // Reset GC + lua_gc(L2, LUA_GCCOLLECT, 0); + + // Try to mark 'L2' as sleeping + // Can't control GC precisely, even in tests + lua_gc(L2, LUA_GCSTEP, 8); + + CHECK(lua_equal(L2, -1, -2) == 1); + lua_pop(L2, 2); + } } static bool endsWith(const std::string& str, const std::string& suffix) @@ -727,8 +765,6 @@ static bool endsWith(const std::string& str, const std::string& suffix) #if !LUA_USE_LONGJMP TEST_CASE("ExceptionObject") { - ScopedFastFlag sff("LuauExceptionMessageFix", true); - struct ExceptionResult { bool exceptionGenerated; diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 29c33f7c1..36d6f5612 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -19,19 +19,6 @@ static const char* mainModuleName = "MainModule"; namespace Luau { -std::optional TestFileResolver::fromAstFragment(AstExpr* expr) const -{ - auto g = expr->as(); - if (!g) - return std::nullopt; - - std::string_view value = g->name.value; - if (value == "game" || value == "Game" || value == "workspace" || value == "Workspace" || value == "script" || value == "Script") - return ModuleName(value); - - return std::nullopt; -} - std::optional TestFileResolver::resolveModule(const ModuleInfo* context, AstExpr* expr) { if (AstExprGlobal* g = expr->as()) @@ -81,24 +68,6 @@ std::optional TestFileResolver::resolveModule(const ModuleInfo* cont return std::nullopt; } -ModuleName TestFileResolver::concat(const ModuleName& lhs, std::string_view rhs) const -{ - return lhs + "/" + ModuleName(rhs); -} - -std::optional TestFileResolver::getParentModuleName(const ModuleName& name) const -{ - std::string_view view = name; - const size_t lastSeparatorIndex = view.find_last_of('/'); - - if (lastSeparatorIndex != std::string_view::npos) - { - return ModuleName(view.substr(0, lastSeparatorIndex)); - } - - return std::nullopt; -} - std::string TestFileResolver::getHumanReadableModuleName(const ModuleName& name) const { return name; @@ -324,6 +293,13 @@ std::optional Fixture::findTypeAtPosition(Position position) return Luau::findTypeAtPosition(*module, *sourceModule, position); } +std::optional Fixture::findExpectedTypeAtPosition(Position position) +{ + ModulePtr module = getMainModule(); + SourceModule* sourceModule = getMainSourceModule(); + return Luau::findExpectedTypeAtPosition(*module, *sourceModule, position); +} + TypeId Fixture::requireTypeAtPosition(Position position) { auto ty = findTypeAtPosition(position); diff --git a/tests/Fixture.h b/tests/Fixture.h index 1480a7f6a..de2b7381e 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -64,12 +64,8 @@ struct TestFileResolver return SourceCode{it->second, sourceType}; } - std::optional fromAstFragment(AstExpr* expr) const override; std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override; - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override; - std::optional getParentModuleName(const ModuleName& name) const override; - std::string getHumanReadableModuleName(const ModuleName& name) const override; std::optional getEnvironmentForModule(const ModuleName& name) const override; @@ -126,6 +122,7 @@ struct Fixture std::optional findTypeAtPosition(Position position); TypeId requireTypeAtPosition(Position position); + std::optional findExpectedTypeAtPosition(Position position); std::optional lookupType(const std::string& name); std::optional lookupImportedType(const std::string& moduleAlias, const std::string& name); diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index fbfec6367..51fcd3d6c 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -46,18 +46,6 @@ NaiveModuleResolver naiveModuleResolver; struct NaiveFileResolver : NullFileResolver { - std::optional fromAstFragment(AstExpr* expr) const override - { - AstExprGlobal* g = expr->as(); - if (g && g->name == "Modules") - return "Modules"; - - if (g && g->name == "game") - return "game"; - - return std::nullopt; - } - std::optional resolveModule(const ModuleInfo* context, AstExpr* expr) override { if (AstExprGlobal* g = expr->as()) @@ -86,11 +74,6 @@ struct NaiveFileResolver : NullFileResolver return std::nullopt; } - - ModuleName concat(const ModuleName& lhs, std::string_view rhs) const override - { - return lhs + "/" + ModuleName(rhs); - } }; } // namespace diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 7ba40c503..1d13df289 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1469,6 +1469,22 @@ _ = true and true or false -- no warning since this is is a common pattern used CHECK_EQ(result.warnings[6].location.begin.line + 1, 19); } +TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsExpr") +{ + LintResult result = lint(R"( +local correct, opaque = ... + +if correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls")}) then +elseif correct({a = 1, b = 2 * (-2), c = opaque.path['with']("calls", false)}) then +end +)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 4"); + CHECK_EQ(result.warnings[0].location.begin.line + 1, 5); +} + TEST_CASE_FIXTURE(Fixture, "DuplicateLocal") { LintResult result = lint(R"( diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 7a3543c7c..2800d2fe6 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -44,9 +44,10 @@ TEST_CASE_FIXTURE(Fixture, "dont_clone_persistent_primitive") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; // numberType is persistent. We leave it as-is. - TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(typeChecker.numberType, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ(newNumber, typeChecker.numberType); } @@ -56,12 +57,13 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_non_persistent_primitive") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; // Create a new number type that isn't persistent unfreeze(typeChecker.globalTypes); TypeId oldNumber = typeChecker.globalTypes.addType(PrimitiveTypeVar{PrimitiveTypeVar::Number}); freeze(typeChecker.globalTypes); - TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks); + TypeId newNumber = clone(oldNumber, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newNumber, oldNumber); CHECK_EQ(*oldNumber, *newNumber); @@ -89,9 +91,10 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_cyclic_table") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; TypeArena dest; - TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks); + TypeId counterCopy = clone(counterType, dest, seenTypes, seenTypePacks, cloneState); TableTypeVar* ttv = getMutable(counterCopy); REQUIRE(ttv != nullptr); @@ -142,11 +145,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_union") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldUnion = typeChecker.globalTypes.addType(UnionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks); + TypeId newUnion = clone(oldUnion, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newUnion, oldUnion); CHECK_EQ("number | string", toString(newUnion)); @@ -159,11 +163,12 @@ TEST_CASE_FIXTURE(Fixture, "deepClone_intersection") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; unfreeze(typeChecker.globalTypes); TypeId oldIntersection = typeChecker.globalTypes.addType(IntersectionTypeVar{{typeChecker.numberType, typeChecker.stringType}}); freeze(typeChecker.globalTypes); - TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks); + TypeId newIntersection = clone(oldIntersection, dest, seenTypes, seenTypePacks, cloneState); CHECK_NE(newIntersection, oldIntersection); CHECK_EQ("number & string", toString(newIntersection)); @@ -188,8 +193,9 @@ TEST_CASE_FIXTURE(Fixture, "clone_class") SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks); + TypeId cloned = clone(&exampleClass, dest, seenTypes, seenTypePacks, cloneState); const ClassTypeVar* ctv = get(cloned); REQUIRE(ctv != nullptr); @@ -211,16 +217,16 @@ TEST_CASE_FIXTURE(Fixture, "clone_sanitize_free_types") TypeArena dest; SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + TypeId clonedTy = clone(&freeTy, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ("any", toString(clonedTy)); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); - encounteredFreeType = false; - TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, &encounteredFreeType); + cloneState = {}; + TypePackId clonedTp = clone(&freeTp, dest, seenTypes, seenTypePacks, cloneState); CHECK_EQ("...any", toString(clonedTp)); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); } TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") @@ -232,12 +238,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") TypeArena dest; SeenTypes seenTypes; SeenTypePacks seenTypePacks; + CloneState cloneState; - bool encounteredFreeType = false; - TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, &encounteredFreeType); + TypeId cloned = clone(&tableTy, dest, seenTypes, seenTypePacks, cloneState); const TableTypeVar* clonedTtv = get(cloned); CHECK_EQ(clonedTtv->state, TableState::Sealed); - CHECK(encounteredFreeType); + CHECK(cloneState.encounteredFreeType); } TEST_CASE_FIXTURE(Fixture, "clone_self_property") @@ -267,4 +273,34 @@ TEST_CASE_FIXTURE(Fixture, "clone_self_property") "dot or pass 1 extra nil to suppress this warning"); } +TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") +{ +#if defined(_DEBUG) || defined(_NOOPT) + int limit = 250; +#else + int limit = 500; +#endif + ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; + + TypeArena src; + + TypeId table = src.addType(TableTypeVar{}); + TypeId nested = table; + + for (unsigned i = 0; i < limit + 100; i++) + { + TableTypeVar* ttv = getMutable(nested); + + ttv->props["a"].type = src.addType(TableTypeVar{}); + nested = ttv->props["a"].type; + } + + TypeArena dest; + SeenTypes seenTypes; + SeenTypePacks seenTypePacks; + CloneState cloneState; + + CHECK_THROWS_AS(clone(table, dest, seenTypes, seenTypePacks, cloneState), std::runtime_error); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index e3e6ce6d8..72d3a9a64 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2518,8 +2518,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") { - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - AstStat* stat = parse(R"( type Packed = () -> T... diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp new file mode 100644 index 000000000..0ca9c9949 --- /dev/null +++ b/tests/ToDot.test.cpp @@ -0,0 +1,366 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/Scope.h" +#include "Luau/ToDot.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +struct ToDotClassFixture : Fixture +{ + ToDotClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + + TypeId baseClassMetaType = arena.addType(TableTypeVar{}); + + TypeId baseClassInstanceType = arena.addType(ClassTypeVar{"BaseClass", {}, std::nullopt, baseClassMetaType, {}, {}}); + getMutable(baseClassInstanceType)->props = { + {"BaseField", {typeChecker.numberType}}, + }; + typeChecker.globalScope->exportedTypeBindings["BaseClass"] = TypeFun{{}, baseClassInstanceType}; + + TypeId childClassInstanceType = arena.addType(ClassTypeVar{"ChildClass", {}, baseClassInstanceType, std::nullopt, {}, {}}); + getMutable(childClassInstanceType)->props = { + {"ChildField", {typeChecker.stringType}}, + }; + typeChecker.globalScope->exportedTypeBindings["ChildClass"] = TypeFun{{}, childClassInstanceType}; + + freeze(arena); + } +}; + +TEST_SUITE_BEGIN("ToDot"); + +TEST_CASE_FIXTURE(Fixture, "primitive") +{ + CheckResult result = check(R"( +local a: nil +local b: number +local c: any +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_NE("nil", toDot(requireType("a"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="number"]; +})", + toDot(requireType("b"))); + + CHECK_EQ(R"(digraph graphname { +n1 [label="any"]; +})", + toDot(requireType("c"))); + + ToDotOptions opts; + opts.showPointers = false; + opts.duplicatePrimitives = false; + + CHECK_EQ(R"(digraph graphname { +n1 [label="PrimitiveTypeVar number"]; +})", + toDot(requireType("b"), opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="AnyTypeVar 1"]; +})", + toDot(requireType("c"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound") +{ + CheckResult result = check(R"( +local a = 444 +local b = a +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypeVar 1"]; +n1 -> n2; +n2 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "function") +{ + ScopedFastFlag luauQuantifyInPlace2{"LuauQuantifyInPlace2", true}; + + CheckResult result = check(R"( +local function f(a, ...: string) return a end +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FunctionTypeVar 1"]; +n1 -> n2 [label="arg"]; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="GenericTypeVar 3"]; +n2 -> n4 [label="tail"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n1 -> n6 [label="ret"]; +n6 [label="BoundTypePack 6"]; +n6 -> n7; +n7 [label="TypePack 7"]; +n7 -> n3; +})", + toDot(requireType("f"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "union") +{ + CheckResult result = check(R"( +local a: string | number +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="UnionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "intersection") +{ + CheckResult result = check(R"( +local a: string & number -- uninhabited +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="IntersectionTypeVar 1"]; +n1 -> n2; +n2 [label="string"]; +n1 -> n3; +n3 [label="number"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "table") +{ + CheckResult result = check(R"( +type A = { x: T, y: (U...) -> (), [string]: any } +local a: A +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar A"]; +n1 -> n2 [label="x"]; +n2 [label="number"]; +n1 -> n3 [label="y"]; +n3 [label="FunctionTypeVar 3"]; +n3 -> n4 [label="arg"]; +n4 [label="VariadicTypePack 4"]; +n4 -> n5; +n5 [label="string"]; +n3 -> n6 [label="ret"]; +n6 [label="TypePack 6"]; +n1 -> n7 [label="[index]"]; +n7 [label="string"]; +n1 -> n8 [label="[value]"]; +n8 [label="any"]; +n1 -> n9 [label="typeParam"]; +n9 [label="number"]; +n1 -> n4 [label="typePackParam"]; +})", + toDot(requireType("a"), opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "metatable") +{ + CheckResult result = check(R"( +local a: typeof(setmetatable({}, {})) +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="MetatableTypeVar 1"]; +n1 -> n2 [label="table"]; +n2 [label="TableTypeVar 2"]; +n1 -> n3 [label="metatable"]; +n3 [label="TableTypeVar 3"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free") +{ + TypeVar type{TypeVariant{FreeTypeVar{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error") +{ + TypeVar type{TypeVariant{ErrorTypeVar{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypeVar 1"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "generic") +{ + TypeVar type{TypeVariant{GenericTypeVar{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypeVar T"]; +})", + toDot(&type, opts)); +} + +TEST_CASE_FIXTURE(ToDotClassFixture, "class") +{ + CheckResult result = check(R"( +local a: ChildClass +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ClassTypeVar ChildClass"]; +n1 -> n2 [label="ChildField"]; +n2 [label="string"]; +n1 -> n3 [label="[parent]"]; +n3 [label="ClassTypeVar BaseClass"]; +n3 -> n4 [label="BaseField"]; +n4 [label="number"]; +n3 -> n5 [label="[metatable]"]; +n5 [label="TableTypeVar 5"]; +})", + toDot(requireType("a"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "free_pack") +{ + TypePackVar pack{TypePackVariant{FreeTypePack{TypeLevel{0, 0}}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="FreeTypePack 1"]; +})", + toDot(&pack, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "error_pack") +{ + TypePackVar pack{TypePackVariant{Unifiable::Error{}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="ErrorTypePack 1"]; +})", + toDot(&pack, opts)); + + // Extra coverage with pointers (unstable values) + (void)toDot(&pack); +} + +TEST_CASE_FIXTURE(Fixture, "generic_pack") +{ + TypePackVar pack1{TypePackVariant{GenericTypePack{}}}; + TypePackVar pack2{TypePackVariant{GenericTypePack{"T"}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack 1"]; +})", + toDot(&pack1, opts)); + + CHECK_EQ(R"(digraph graphname { +n1 [label="GenericTypePack T"]; +})", + toDot(&pack2, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_pack") +{ + TypePackVar pack{TypePackVariant{TypePack{{typeChecker.numberType}, {}}}}; + TypePackVar bound{TypePackVariant{BoundTypePack{&pack}}}; + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="BoundTypePack 1"]; +n1 -> n2; +n2 [label="TypePack 2"]; +n2 -> n3; +n3 [label="number"]; +})", + toDot(&bound, opts)); +} + +TEST_CASE_FIXTURE(Fixture, "bound_table") +{ + CheckResult result = check(R"( +local a = {x=2} +local b +b.x = 2 +b = a +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = getType("b"); + REQUIRE(bool(ty)); + + ToDotOptions opts; + opts.showPointers = false; + CHECK_EQ(R"(digraph graphname { +n1 [label="TableTypeVar 1"]; +n1 -> n2 [label="boundTo"]; +n2 [label="TableTypeVar a"]; +n2 -> n3 [label="x"]; +n3 [label="number"]; +})", + toDot(*ty, opts)); +} + +TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 928c03a31..327fa0bbd 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -445,9 +445,6 @@ local a: Import.Type TEST_CASE_FIXTURE(Fixture, "transpile_type_packs") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - std::string code = R"( type Packed = (T...)->(T...) local a: Packed<> diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 74ce155c2..822bd727e 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -537,8 +537,6 @@ TEST_CASE_FIXTURE(Fixture, "free_variables_from_typeof_in_aliases") TEST_CASE_FIXTURE(Fixture, "non_recursive_aliases_that_reuse_a_generic_name") { - ScopedFastFlag sff1{"LuauSubstitutionDontReplaceIgnoredTypes", true}; - CheckResult result = check(R"( type Array = { [number]: T } type Tuple = Array diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 88c2dc85d..aba508918 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -609,8 +609,6 @@ TEST_CASE_FIXTURE(Fixture, "typefuns_sharing_types") TEST_CASE_FIXTURE(Fixture, "bound_tables_do_not_clone_original_fields") { - ScopedFastFlag luauCloneBoundTables{"LuauCloneBoundTables", true}; - CheckResult result = check(R"( local exports = {} local nested = {} @@ -627,4 +625,23 @@ return exports LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") +{ + ScopedFastFlag luauFunctionArgumentNameSize{"LuauFunctionArgumentNameSize", true}; + + CheckResult result = check(R"( +local function f(a: T, ...: U...) end + +f(1, 2, 3) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + auto ty = findTypeAtPosition(Position(3, 0)); + REQUIRE(ty); + ToStringOptions opts; + opts.functionTypeArguments = true; + CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e5c14dde8..e6d3d4d47 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -31,8 +31,6 @@ TEST_SUITE_BEGIN("ProvisionalTests"); */ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - const std::string code = R"( function f(a) if type(a) == "boolean" then diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index c3694be79..cb72faaf4 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2022,4 +2022,74 @@ caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); } +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + {"LuauExtendedTypeMismatchError", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { p : Super } +type HasSub = { p : Sub } +local tmp = { p = { x = 5, y = 7 }} +local a: HasSuper = tmp +a.p = { x = 9 } +-- needs to be an error because +local y: number = tmp.p.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' +caused by: + Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); +} + +TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") +{ + ScopedFastFlag sffs[] { + {"LuauPropertiesGetExpectedType", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauTableSubtypingVariance", true}, + }; + + CheckResult result = check(R"( +--!strict +type Super = { x : number } +type Sub = { x : number, y: number } +type HasSuper = { [string] : Super } +type HasSub = { [string] : Sub } +local a: HasSuper = { p = { x = 5, y = 7 }} +a.p = { x = 9 } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 99fd8339c..e3222a410 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4779,4 +4779,24 @@ local bar = foo.nutrition + 100 // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); } +TEST_CASE_FIXTURE(Fixture, "require_failed_module") +{ + ScopedFastFlag luauModuleRequireErrorPack{"LuauModuleRequireErrorPack", true}; + + fileResolver.source["game/A"] = R"( +return unfortunately() + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(aResult); + + CheckResult result = check(R"( +local ModuleA = require(game.A) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional oty = requireType("ModuleA"); + CHECK_EQ("*unknown*", toString(*oty)); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index c6de0abf5..3f4420cda 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -296,9 +296,6 @@ end TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed = (T...) -> T... local a: Packed<> @@ -360,9 +357,6 @@ local c: Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_import") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } return {} @@ -393,9 +387,6 @@ local d: { a: typeof(c) } TEST_CASE_FIXTURE(Fixture, "type_pack_type_parameters") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - fileResolver.source["game/A"] = R"( export type Packed = { a: T, b: (U...) -> () } return {} @@ -431,9 +422,6 @@ type C = Import.Packed TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_nested") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed1 = (T...) -> (T...) type Packed2 = (Packed1, T...) -> (Packed1, T...) @@ -452,9 +440,6 @@ type Packed4 = (Packed3, T...) -> (Packed3, T...) TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_variadic") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = (T...) -> (string, T...) @@ -470,9 +455,6 @@ type E = X<(number, ...string)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_multi") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = (T...) -> (U...) type A = Y @@ -501,9 +483,6 @@ type I = W TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = (T...) -> (T...) @@ -527,9 +506,6 @@ type F = X<(string, ...number)> TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = (T...) -> (U...) @@ -549,9 +525,6 @@ type D = Y TEST_CASE_FIXTURE(Fixture, "type_alias_type_pack_explicit_multi_tostring") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Y = { f: (T...) -> (U...) } @@ -567,9 +540,6 @@ local b: Y<(), ()> TEST_CASE_FIXTURE(Fixture, "type_alias_backwards_compatible") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type X = () -> T type Y = (T) -> U @@ -588,9 +558,6 @@ type C = Y<(number), boolean> TEST_CASE_FIXTURE(Fixture, "type_alias_type_packs_errors") { - ScopedFastFlag luauTypeAliasPacks("LuauTypeAliasPacks", true); - ScopedFastFlag luauParseTypePackTypeParameters("LuauParseTypePackTypeParameters", true); - CheckResult result = check(R"( type Packed = (T, U) -> (V...) local b: Packed diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 91efa8188..13db923ed 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -3,6 +3,7 @@ #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -323,4 +324,48 @@ TEST_CASE("tagging_props") CHECK(Luau::hasTag(prop, "foo")); } +struct VisitCountTracker +{ + std::unordered_map tyVisits; + std::unordered_map tpVisits; + + void cycle(TypeId) {} + void cycle(TypePackId) {} + + template + bool operator()(TypeId ty, const T& t) + { + tyVisits[ty]++; + return true; + } + + template + bool operator()(TypePackId tp, const T&) + { + tpVisits[tp]++; + return true; + } +}; + +TEST_CASE_FIXTURE(Fixture, "visit_once") +{ + CheckResult result = check(R"( +type T = { a: number, b: () -> () } +local b: (T, T, T) -> T +)"); + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId bType = requireType("b"); + + VisitCountTracker tester; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(bType, tester, seen); + + for (auto [_, count] : tester.tyVisits) + CHECK_EQ(count, 1); + + for (auto [_, count] : tester.tpVisits) + CHECK_EQ(count, 1); +} + TEST_SUITE_END(); diff --git a/tests/conformance/apicalls.lua b/tests/conformance/apicalls.lua index 5e03b0559..7a4058b52 100644 --- a/tests/conformance/apicalls.lua +++ b/tests/conformance/apicalls.lua @@ -2,7 +2,13 @@ print('testing function calls through API') function add(a, b) - return a + b + return a + b +end + +local m = { __eq = function(a, b) return a.a == b.a end } + +function create_with_tm(x) + return setmetatable({ a = x }, m) end return('OK') diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 687fff1ee..188b8ebc4 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -441,7 +441,8 @@ assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r en assert((function() a = {} b = {} function eq(l, r) return #l == #r end setmetatable(a, {__eq = eq}) setmetatable(b, {__eq = eq}) return concat(a == b, a ~= b) end)() == "true,false") assert((function() a = {} b = {} setmetatable(a, {__eq = function(l, r) return #l == #r end}) setmetatable(b, {__eq = function(l, r) return #l == #r end}) return concat(a == b, a ~= b) end)() == "false,true") --- userdata, reference equality (no mt) +-- userdata, reference equality (no mt or mt.__eq) +assert((function() a = newproxy() return concat(a == newproxy(),a ~= newproxy()) end)() == "false,true") assert((function() a = newproxy(true) return concat(a == newproxy(true),a ~= newproxy(true)) end)() == "false,true") -- rawequal @@ -876,4 +877,4 @@ assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number testgetfenv() -- DONT MOVE THIS LINE -return'OK' +return 'OK' diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index aac42c56f..f32d5bdcb 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -419,11 +419,5 @@ co = coroutine.create(function () return loadstring("return a")() end) -a = {a = 15} --- debug.setfenv(co, a) --- assert(debug.getfenv(co) == a) --- assert(select(2, coroutine.resume(co)) == a) --- assert(select(2, coroutine.resume(co)) == a.a) - -return'OK' +return 'OK' diff --git a/tests/conformance/constructs.lua b/tests/conformance/constructs.lua index 16c63b00f..f133501f1 100644 --- a/tests/conformance/constructs.lua +++ b/tests/conformance/constructs.lua @@ -237,4 +237,4 @@ repeat i = i+1 until i==c -return'OK' +return 'OK' diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index 4d9b12953..f2ecc96bb 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -373,4 +373,4 @@ do assert(f() == 42) end -return'OK' +return 'OK' diff --git a/tests/conformance/datetime.lua b/tests/conformance/datetime.lua index 21ef60d74..ca35cf2f1 100644 --- a/tests/conformance/datetime.lua +++ b/tests/conformance/datetime.lua @@ -74,4 +74,4 @@ assert(os.difftime(t1,t2) == 60*2-19) assert(os.time({ year = 1970, day = 1, month = 1, hour = 0}) == 0) -return'OK' +return 'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index ee79a14fd..9cf3c7423 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -98,4 +98,4 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") -return'OK' +return 'OK' diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index eded14e9d..d5ff215b4 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -34,15 +34,15 @@ assert(doit("error('hi', 0)") == 'hi') assert(doit("unpack({}, 1, n=2^30)")) assert(doit("a=math.sin()")) assert(not doit("tostring(1)") and doit("tostring()")) -assert(doit"tonumber()") -assert(doit"repeat until 1; a") +assert(doit("tonumber()")) +assert(doit("repeat until 1; a")) checksyntax("break label", "", "label", 1) -assert(doit";") -assert(doit"a=1;;") -assert(doit"return;;") -assert(doit"assert(false)") -assert(doit"assert(nil)") -assert(doit"a=math.sin\n(3)") +assert(doit(";")) +assert(doit("a=1;;")) +assert(doit("return;;")) +assert(doit("assert(false)")) +assert(doit("assert(nil)")) +assert(doit("a=math.sin\n(3)")) assert(doit("function a (... , ...) end")) assert(doit("function a (, ...) end")) @@ -59,7 +59,7 @@ checkmessage("a=1; local a,bbbb=2,3; a = math.sin(1) and bbbb(3)", "local 'bbbb'") checkmessage("a={}; do local a=1 end a:bbbb(3)", "method 'bbbb'") checkmessage("local a={}; a.bbbb(3)", "field 'bbbb'") -assert(not string.find(doit"a={13}; local bbbb=1; a[bbbb](3)", "'bbbb'")) +assert(not string.find(doit("a={13}; local bbbb=1; a[bbbb](3)"), "'bbbb'")) checkmessage("a={13}; local bbbb=1; a[bbbb](3)", "number") aaa = nil @@ -67,14 +67,14 @@ checkmessage("aaa.bbb:ddd(9)", "global 'aaa'") checkmessage("local aaa={bbb=1}; aaa.bbb:ddd(9)", "field 'bbb'") checkmessage("local aaa={bbb={}}; aaa.bbb:ddd(9)", "method 'ddd'") checkmessage("local a,b,c; (function () a = b+1 end)()", "upvalue 'b'") -assert(not doit"local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)") +assert(not doit("local aaa={bbb={ddd=next}}; aaa.bbb:ddd(nil)")) checkmessage("b=1; local aaa='a'; x=aaa+b", "local 'aaa'") checkmessage("aaa={}; x=3/aaa", "global 'aaa'") checkmessage("aaa='2'; b=nil;x=aaa*b", "global 'b'") checkmessage("aaa={}; x=-aaa", "global 'aaa'") -assert(not string.find(doit"aaa={}; x=(aaa or aaa)+(aaa and aaa)", "'aaa'")) -assert(not string.find(doit"aaa={}; (aaa or aaa)()", "'aaa'")) +assert(not string.find(doit("aaa={}; x=(aaa or aaa)+(aaa and aaa)"), "'aaa'")) +assert(not string.find(doit("aaa={}; (aaa or aaa)()"), "'aaa'")) checkmessage([[aaa=9 repeat until 3==3 @@ -122,10 +122,10 @@ function lineerror (s) return line and line+0 end -assert(lineerror"local a\n for i=1,'a' do \n print(i) \n end" == 2) --- assert(lineerror"\n local a \n for k,v in 3 \n do \n print(k) \n end" == 3) --- assert(lineerror"\n\n for k,v in \n 3 \n do \n print(k) \n end" == 4) -assert(lineerror"function a.x.y ()\na=a+1\nend" == 1) +assert(lineerror("local a\n for i=1,'a' do \n print(i) \n end") == 2) +-- assert(lineerror("\n local a \n for k,v in 3 \n do \n print(k) \n end") == 3) +-- assert(lineerror("\n\n for k,v in \n 3 \n do \n print(k) \n end") == 4) +assert(lineerror("function a.x.y ()\na=a+1\nend") == 1) local p = [[ function g() f() end diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 4263dfda7..6d9eb8544 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -77,7 +77,7 @@ end local function dosteps (siz) collectgarbage() - collectgarbage"stop" + collectgarbage("stop") local a = {} for i=1,100 do a[i] = {{}}; local b = {} end local x = gcinfo() @@ -99,11 +99,11 @@ assert(dosteps(10000) == 1) do local x = gcinfo() collectgarbage() - collectgarbage"stop" + collectgarbage("stop") repeat local a = {} until gcinfo() > 1000 - collectgarbage"restart" + collectgarbage("restart") repeat local a = {} until gcinfo() < 1000 @@ -123,7 +123,7 @@ for n in pairs(b) do end b = nil collectgarbage() -for n in pairs(a) do error'cannot be here' end +for n in pairs(a) do error("cannot be here") end for i=1,lim do a[i] = i end for i=1,lim do assert(a[i] == i) end diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index 7f9b75962..94ba5ccfb 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -368,9 +368,9 @@ assert(next(a,nil) == 1000 and next(a,1000) == nil) assert(next({}) == nil) assert(next({}, nil) == nil) -for a,b in pairs{} do error"not here" end -for i=1,0 do error'not here' end -for i=0,1,-1 do error'not here' end +for a,b in pairs{} do error("not here") end +for i=1,0 do error("not here") end +for i=0,1,-1 do error("not here") end a = nil; for i=1,1 do assert(not a); a=1 end; assert(a) a = nil; for i=1,1,-1 do assert(not a); a=1 end; assert(a) diff --git a/tests/conformance/pcall.lua b/tests/conformance/pcall.lua index a2072d2c8..84ac2ba19 100644 --- a/tests/conformance/pcall.lua +++ b/tests/conformance/pcall.lua @@ -144,4 +144,4 @@ coroutine.resume(co) resumeerror(co, "fail") checkresults({ true, false, "fail" }, coroutine.resume(co)) -return'OK' +return 'OK' diff --git a/tests/conformance/utf8.lua b/tests/conformance/utf8.lua index 024cb16d7..bfd7a1ac8 100644 --- a/tests/conformance/utf8.lua +++ b/tests/conformance/utf8.lua @@ -205,4 +205,4 @@ for p, c in string.gmatch(x, "()(" .. utf8.charpattern .. ")") do end end -return'OK' +return 'OK' diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 620f646aa..7d18bda33 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -1,6 +1,9 @@ -- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details print('testing vectors') +-- detect vector size +local vector_size = if pcall(function() return vector(0, 0, 0).w end) then 4 else 3 + -- equality assert(vector(1, 2, 3) == vector(1, 2, 3)) assert(vector(0, 1, 2) == vector(-0, 1, 2)) @@ -13,8 +16,14 @@ assert(not rawequal(vector(1, 2, 3), vector(1, 2, 4))) -- type & tostring assert(type(vector(1, 2, 3)) == "vector") -assert(tostring(vector(1, 2, 3)) == "1, 2, 3") -assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") + +if vector_size == 4 then + assert(tostring(vector(1, 2, 3, 4)) == "1, 2, 3, 4") + assert(tostring(vector(-1, 2, 0.5, 0)) == "-1, 2, 0.5, 0") +else + assert(tostring(vector(1, 2, 3)) == "1, 2, 3") + assert(tostring(vector(-1, 2, 0.5)) == "-1, 2, 0.5") +end local t = {} @@ -42,12 +51,19 @@ assert(8 * vector(8, 16, 24) == vector(64, 128, 192)); assert(vector(1, 2, 4) * '8' == vector(8, 16, 32)); assert('8' * vector(8, 16, 24) == vector(64, 128, 192)); -assert(vector(1, 2, 4) / vector(8, 16, 24) == vector(1/8, 2/16, 4/24)); +if vector_size == 4 then + assert(vector(1, 2, 4, 8) / vector(8, 16, 24, 32) == vector(1/8, 2/16, 4/24, 8/32)); + assert(8 / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); + assert('8' / vector(8, 16, 24, 32) == vector(1, 1/2, 1/3, 1/4)); +else + assert(vector(1, 2, 4) / vector(8, 16, 24, 1) == vector(1/8, 2/16, 4/24)); + assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); + assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); +end + assert(vector(1, 2, 4) / 8 == vector(1/8, 1/4, 1/2)); assert(vector(1, 2, 4) / (1 / val) == vector(1/8, 2/8, 4/8)); -assert(8 / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(vector(1, 2, 4) / '8' == vector(1/8, 1/4, 1/2)); -assert('8' / vector(8, 16, 24) == vector(1, 1/2, 1/3)); assert(-vector(1, 2, 4) == vector(-1, -2, -4)); @@ -71,4 +87,9 @@ assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == fal -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) +-- additional checks for 4-component vectors +if vector_size == 4 then + assert(vector(1, 2, 3, 4).w == 4) +end + return 'OK' diff --git a/tools/svg.py b/tools/svg.py index 3b3bb28c4..99853fb6e 100644 --- a/tools/svg.py +++ b/tools/svg.py @@ -458,13 +458,16 @@ def display(root, title, colors, flip = False): framewidth = 1200 - 20 + def pixels(x): + return float(x) / root.width * framewidth if root.width > 0 else 0 + for n in root.subtree(): - if n.width / root.width * framewidth < 0.1: + if pixels(n.width) < 0.1: continue - x = 10 + n.offset / root.width * framewidth + x = 10 + pixels(n.offset) y = (maxdepth - 1 - n.depth if flip else n.depth) * 16 + 3 * 16 - width = n.width / root.width * framewidth + width = pixels(n.width) height = 15 if colors == "cold": From e440729e2bb98aba0bdb41109e257de522b857e4 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 2 Dec 2021 15:46:33 -0800 Subject: [PATCH 09/32] Fix signed/unsigned mismatch warning + lower limit to match upstream --- tests/Module.test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2800d2fe6..e3993cc53 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -278,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #if defined(_DEBUG) || defined(_NOOPT) int limit = 250; #else - int limit = 500; + int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; @@ -287,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeId table = src.addType(TableTypeVar{}); TypeId nested = table; - for (unsigned i = 0; i < limit + 100; i++) + for (int i = 0; i < limit + 100; i++) { TableTypeVar* ttv = getMutable(nested); From a8673f0f99885da92597d6efb4feaf0f14d59991 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 10 Dec 2021 13:17:10 -0800 Subject: [PATCH 10/32] Sync to upstream/release/507-pre This doesn't contain all changes for 507 yet but we might want to do the Luau 0.507 release a bit earlier to end the year sooner. --- Analysis/include/Luau/Error.h | 11 +- Analysis/include/Luau/IostreamHelpers.h | 4 + Analysis/include/Luau/ToString.h | 4 +- Analysis/include/Luau/TypeInfer.h | 7 +- Analysis/include/Luau/TypeVar.h | 8 +- Analysis/include/Luau/Unifiable.h | 11 +- Analysis/include/Luau/Unifier.h | 3 + Analysis/src/Autocomplete.cpp | 115 +++++++-- Analysis/src/BuiltinDefinitions.cpp | 9 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 12 +- Analysis/src/Error.cpp | 17 +- Analysis/src/IostreamHelpers.cpp | 6 + Analysis/src/JsonEncoder.cpp | 4 +- Analysis/src/Module.cpp | 19 +- Analysis/src/Predicate.cpp | 2 +- Analysis/src/ToString.cpp | 38 ++- Analysis/src/TypeInfer.cpp | 189 +++++++++----- Analysis/src/TypeUtils.cpp | 6 +- Analysis/src/TypeVar.cpp | 220 ++++------------ Analysis/src/Unifier.cpp | 266 +++++++++++++++++--- Ast/src/Parser.cpp | 5 +- CLI/Analyze.cpp | 20 +- CLI/Coverage.cpp | 88 +++++++ CLI/Coverage.h | 10 + CLI/FileUtils.cpp | 4 + CLI/Profiler.cpp | 8 +- CLI/Profiler.h | 2 +- CLI/Repl.cpp | 35 ++- Compiler/src/Compiler.cpp | 22 +- Makefile | 6 +- Sources.cmake | 4 + VM/include/lua.h | 4 + VM/src/lapi.cpp | 162 ++++++------ VM/src/ldebug.cpp | 63 +++++ VM/src/lgc.cpp | 49 +--- VM/src/lgcdebug.cpp | 1 + VM/src/lobject.h | 10 +- VM/src/lstring.cpp | 29 --- VM/src/lstring.h | 7 - VM/src/ludata.cpp | 37 +++ VM/src/ludata.h | 13 + VM/src/lvmexecute.cpp | 14 +- VM/src/lvmutils.cpp | 4 +- bench/tests/sunspider/3d-raytrace.lua | 15 +- tests/Autocomplete.test.cpp | 55 +++- tests/Compiler.test.cpp | 49 +--- tests/Conformance.test.cpp | 100 ++++++-- tests/Module.test.cpp | 4 +- tests/Parser.test.cpp | 4 - tests/TypeInfer.annotations.test.cpp | 30 +++ tests/TypeInfer.builtins.test.cpp | 26 ++ tests/TypeInfer.generics.test.cpp | 40 +++ tests/TypeInfer.refinements.test.cpp | 17 +- tests/TypeInfer.singletons.test.cpp | 50 ++++ tests/TypeInfer.tables.test.cpp | 61 +++-- tests/TypeInfer.test.cpp | 239 +++++++++++++++++- tests/TypeInfer.tryUnify.test.cpp | 28 +++ tests/TypeInfer.unionTypes.test.cpp | 16 ++ tests/conformance/coverage.lua | 64 +++++ 59 files changed, 1703 insertions(+), 643 deletions(-) create mode 100644 CLI/Coverage.cpp create mode 100644 CLI/Coverage.h create mode 100644 VM/src/ludata.cpp create mode 100644 VM/src/ludata.h create mode 100644 tests/conformance/coverage.lua diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 9ee750043..aff3c4d9e 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -277,11 +277,20 @@ struct MissingUnionProperty bool operator==(const MissingUnionProperty& rhs) const; }; +struct TypesAreUnrelated +{ + TypeId left; + TypeId right; + + bool operator==(const TypesAreUnrelated& rhs) const; +}; + using TypeErrorData = Variant; + DuplicateGenericParameter, CannotInferBinaryOperation, MissingProperties, SwappedGenericTypeParameter, OptionalValueAccess, MissingUnionProperty, + TypesAreUnrelated>; struct TypeError { diff --git a/Analysis/include/Luau/IostreamHelpers.h b/Analysis/include/Luau/IostreamHelpers.h index f9e9cd48c..ee994296c 100644 --- a/Analysis/include/Luau/IostreamHelpers.h +++ b/Analysis/include/Luau/IostreamHelpers.h @@ -36,6 +36,10 @@ std::ostream& operator<<(std::ostream& lhs, const IllegalRequire& error); std::ostream& operator<<(std::ostream& lhs, const ModuleHasCyclicDependency& error); std::ostream& operator<<(std::ostream& lhs, const DuplicateGenericParameter& error); std::ostream& operator<<(std::ostream& lhs, const CannotInferBinaryOperation& error); +std::ostream& operator<<(std::ostream& lhs, const SwappedGenericTypeParameter& error); +std::ostream& operator<<(std::ostream& lhs, const OptionalValueAccess& error); +std::ostream& operator<<(std::ostream& lhs, const MissingUnionProperty& error); +std::ostream& operator<<(std::ostream& lhs, const TypesAreUnrelated& error); std::ostream& operator<<(std::ostream& lhs, const TableState& tv); std::ostream& operator<<(std::ostream& lhs, const TypeVar& tv); diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index 50379c1cd..a97bf6d6b 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -69,8 +69,8 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression -void dump(TypeId ty); -void dump(TypePackId ty); +std::string dump(TypeId ty); +std::string dump(TypePackId ty); std::string generateName(size_t n); diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 9f553bc14..451976e48 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -156,13 +156,14 @@ struct TypeChecker // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). // Note: the binding may be null. + // TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); - TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName); + TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, std::optional originalNameLoc, std::optional expectedType); void checkFunctionBody(const ScopePtr& scope, TypeId type, const AstExprFunction& function); @@ -174,7 +175,7 @@ struct TypeChecker ExprResult checkExprPack(const ScopePtr& scope, const AstExprCall& expr); std::vector> getExpectedTypesForCall(const std::vector& overloads, size_t argumentCount, bool selfCall); std::optional> checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors); bool handleSelfCallMismatch(const ScopePtr& scope, const AstExprCall& expr, TypePack* args, const std::vector& argLocations, const std::vector& errors); @@ -277,7 +278,7 @@ struct TypeChecker [[noreturn]] void ice(const std::string& message); ScopePtr childFunctionScope(const ScopePtr& parent, const Location& location, int subLevel = 0); - ScopePtr childScope(const ScopePtr& parent, const Location& location, int subLevel = 0); + ScopePtr childScope(const ScopePtr& parent, const Location& location); // Wrapper for merge(l, r, toUnion) but without the lambda junk. void merge(RefinementMap& l, const RefinementMap& r); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 8c4c2f34f..f6829ec3e 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -499,6 +499,7 @@ struct SingletonTypes const TypePackId anyTypePack; SingletonTypes(); + ~SingletonTypes(); SingletonTypes(const SingletonTypes&) = delete; void operator=(const SingletonTypes&) = delete; @@ -509,10 +510,12 @@ struct SingletonTypes private: std::unique_ptr arena; + bool debugFreezeArena = false; + TypeId makeStringMetatable(); }; -extern SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes(); void persist(TypeId ty); void persist(TypePackId tp); @@ -523,9 +526,6 @@ TypeLevel* getMutableLevel(TypeId ty); const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name); bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent); -bool hasGeneric(TypeId ty); -bool hasGeneric(TypePackId tp); - TypeVar* asMutable(TypeId ty); template diff --git a/Analysis/include/Luau/Unifiable.h b/Analysis/include/Luau/Unifiable.h index b47610fca..e8eafe688 100644 --- a/Analysis/include/Luau/Unifiable.h +++ b/Analysis/include/Luau/Unifiable.h @@ -24,7 +24,7 @@ struct TypeLevel int level = 0; int subLevel = 0; - // Returns true if the typelevel "this" is "bigger" than rhs + // Returns true if the level of "this" belongs to an equal or larger scope than that of rhs bool subsumes(const TypeLevel& rhs) const { if (level < rhs.level) @@ -38,6 +38,15 @@ struct TypeLevel return false; } + // Returns true if the level of "this" belongs to a larger (not equal) scope than that of rhs + bool subsumesStrict(const TypeLevel& rhs) const + { + if (level == rhs.level && subLevel == rhs.subLevel) + return false; + else + return subsumes(rhs); + } + TypeLevel incr() const { TypeLevel result; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4588cdd8c..7681b9662 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -91,6 +91,9 @@ struct Unifier [[noreturn]] void ice(const std::string& message, const Location& location); [[noreturn]] void ice(const std::string& message); + + // Available after regular type pack unification errors + std::optional firstPackErrorPos; }; } // namespace Luau diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index db2d1d0e5..4b583792c 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); +LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -190,7 +191,48 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve return ParenthesesRecommendation::None; } -static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, TypeId ty) +static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) +{ + LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg); + + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + // Extra care for first function call argument location + // When we don't have anything inside () yet, we also don't have an AST node to base our lookup + if (AstExprCall* exprCall = expr->as()) + { + if (exprCall->args.size == 0 && exprCall->argLocation.contains(position)) + { + auto it = module.astTypes.find(exprCall->func); + + if (!it) + return std::nullopt; + + const FunctionTypeVar* ftv = get(follow(*it)); + + if (!ftv) + return std::nullopt; + + auto [head, tail] = flatten(ftv->argTypes); + unsigned index = exprCall->self ? 1 : 0; + + if (index < head.size()) + return head[index]; + + return std::nullopt; + } + } + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + return *it; +} + +static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, Position position, TypeId ty) { ty = follow(ty); @@ -220,15 +262,29 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } }; - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return TypeCorrectKind::None; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); + + if (!typeAtPosition) + return TypeCorrectKind::None; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return TypeCorrectKind::None; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return TypeCorrectKind::None; - TypeId expectedType = follow(*it); + expectedType = follow(*it); + } if (FFlag::LuauAutocompletePreferToCallFunctions) { @@ -333,8 +389,8 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (result.count(name) == 0 && name != Parser::errorName) { Luau::TypeId type = Luau::follow(prop.type); - TypeCorrectKind typeCorrect = - indexType == PropIndexType::Key ? TypeCorrectKind::Correct : checkTypeCorrectKind(module, typeArena, nodes.back(), type); + TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct + : checkTypeCorrectKind(module, typeArena, nodes.back(), {{}, {}}, type); ParenthesesRecommendation parens = indexType == PropIndexType::Key ? ParenthesesRecommendation::None : getParenRecommendation(type, nodes, typeCorrect); @@ -692,17 +748,31 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) return ret; } -static std::optional functionIsExpectedAt(const Module& module, AstNode* node) +static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; + TypeId expectedType; - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; + if (FFlag::LuauAutocompleteFirstArg) + { + auto typeAtPosition = findExpectedTypeAt(module, node, position); - TypeId expectedType = follow(*it); + if (!typeAtPosition) + return std::nullopt; + + expectedType = follow(*typeAtPosition); + } + else + { + auto expr = node->asExpr(); + if (!expr) + return std::nullopt; + + auto it = module.astExpectedTypes.find(expr); + if (!it) + return std::nullopt; + + expectedType = follow(*it); + } if (get(expectedType)) return true; @@ -1171,7 +1241,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul std::string n = toString(name); if (!result.count(n)) { - TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, binding.typeId); + TypeCorrectKind typeCorrect = checkTypeCorrectKind(module, typeArena, node, position, binding.typeId); result[n] = {AutocompleteEntryKind::Binding, binding.typeId, binding.deprecated, false, typeCorrect, std::nullopt, std::nullopt, binding.documentationSymbol, {}, getParenRecommendation(binding.typeId, ancestry, typeCorrect)}; @@ -1181,9 +1251,10 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul scope = scope->parent; } - TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, typeChecker.nilType); - TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, typeChecker.booleanType); - TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + TypeCorrectKind correctForNil = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.nilType); + TypeCorrectKind correctForBoolean = checkTypeCorrectKind(module, typeArena, node, position, typeChecker.booleanType); + TypeCorrectKind correctForFunction = + functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; if (FFlag::LuauIfElseExpressionAnalysisSupport) result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index bac94a2bd..d527414a1 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -217,9 +217,9 @@ void registerBuiltinTypes(TypeChecker& typeChecker) TypeId genericK = arena.addType(GenericTypeVar{"K"}); TypeId genericV = arena.addType(GenericTypeVar{"V"}); - TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level}); + TypeId mapOfKtoV = arena.addType(TableTypeVar{{}, TableIndexer(genericK, genericV), typeChecker.globalScope->level, TableState::Generic}); - std::optional stringMetatableTy = getMetatable(singletonTypes.stringType); + std::optional stringMetatableTy = getMetatable(getSingletonTypes().stringType); LUAU_ASSERT(stringMetatableTy); const TableTypeVar* stringMetatableTable = get(follow(*stringMetatableTy)); LUAU_ASSERT(stringMetatableTable); @@ -271,7 +271,10 @@ void registerBuiltinTypes(TypeChecker& typeChecker) persist(pair.second.typeId); if (TableTypeVar* ttv = getMutable(pair.second.typeId)) - ttv->name = toString(pair.first); + { + if (!ttv->name) + ttv->name = toString(pair.first); + } } attachMagicFunction(getGlobalBinding(typeChecker, "assert"), magicFunctionAssert); diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 9f5c82500..d0afa7424 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,6 +1,8 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAGVARIABLE(LuauFixTonumberReturnType, false) + namespace Luau { @@ -113,7 +115,6 @@ declare function gcinfo(): number declare function error(message: T, level: number?) declare function tostring(value: T): string - declare function tonumber(value: T, radix: number?): number declare function rawequal(a: T1, b: T2): boolean declare function rawget(tab: {[K]: V}, k: K): V @@ -204,7 +205,14 @@ declare function gcinfo(): number std::string getBuiltinDefinitionSource() { - return kBuiltinDefinitionLuaSrc; + std::string result = kBuiltinDefinitionLuaSrc; + + if (FFlag::LuauFixTonumberReturnType) + result += "declare function tonumber(value: T, radix: number?): number?\n"; + else + result += "declare function tonumber(value: T, radix: number?): number\n"; + + return result; } } // namespace Luau diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 8334bd626..ce832c6b3 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -58,7 +58,7 @@ struct ErrorConverter result += "\ncaused by:\n "; if (!tm.reason.empty()) - result += tm.reason + ". "; + result += tm.reason + " "; result += Luau::toString(*tm.error); } @@ -410,6 +410,11 @@ struct ErrorConverter return ss + " in the type '" + toString(e.type) + "'"; } + + std::string operator()(const TypesAreUnrelated& e) const + { + return "Cannot cast '" + toString(e.left) + "' into '" + toString(e.right) + "' because the types are unrelated"; + } }; struct InvalidNameChecker @@ -658,6 +663,11 @@ bool MissingUnionProperty::operator==(const MissingUnionProperty& rhs) const return *type == *rhs.type && key == rhs.key; } +bool TypesAreUnrelated::operator==(const TypesAreUnrelated& rhs) const +{ + return left == rhs.left && right == rhs.right; +} + std::string toString(const TypeError& error) { ErrorConverter converter; @@ -793,6 +803,11 @@ void copyError(T& e, TypeArena& destArena, SeenTypes& seenTypes, SeenTypePacks& for (auto& ty : e.missing) ty = clone(ty); } + else if constexpr (std::is_same_v) + { + e.left = clone(e.left); + e.right = clone(e.right); + } else static_assert(always_false_v, "Non-exhaustive type switch"); } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index ac46b5a49..5bc76ade5 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -262,6 +262,12 @@ std::ostream& operator<<(std::ostream& stream, const MissingUnionProperty& error return stream << " }, key = '" + error.key + "' }"; } +std::ostream& operator<<(std::ostream& stream, const TypesAreUnrelated& error) +{ + stream << "TypesAreUnrelated { left = '" + toString(error.left) + "', right = '" + toString(error.right) + "' }"; + return stream; +} + std::ostream& operator<<(std::ostream& stream, const TableState& tv) { return stream << static_cast::type>(tv); diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index c7f623eea..23491a5a1 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -262,7 +262,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(a); } @@ -379,7 +379,7 @@ struct AstJsonEncoder : public AstVisitor if (comma) writeRaw(","); else - comma = false; + comma = true; write(prop); } }); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index b4b6eb425..e1e53c971 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,7 +13,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTFLAG(LuauCaptureBrokenCommentSpans) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) namespace Luau @@ -23,7 +22,7 @@ static bool contains(Position pos, Comment comment) { if (comment.location.contains(pos)) return true; - else if (FFlag::LuauCaptureBrokenCommentSpans && comment.type == Lexeme::BrokenComment && + else if (comment.type == Lexeme::BrokenComment && comment.location.begin <= pos) // Broken comments are broken specifically because they don't have an end return true; else if (comment.type == Lexeme::Comment && comment.location.end == pos) @@ -194,7 +193,7 @@ struct TypePackCloner { cloneState.encounteredFreeType = true; - TypePackId err = singletonTypes.errorRecoveryTypePack(singletonTypes.anyTypePack); + TypePackId err = getSingletonTypes().errorRecoveryTypePack(getSingletonTypes().anyTypePack); TypePackId cloned = dest.addTypePack(*err); seenTypePacks[typePackId] = cloned; } @@ -247,7 +246,7 @@ void TypeCloner::defaultClone(const T& t) void TypeCloner::operator()(const Unifiable::Free& t) { cloneState.encounteredFreeType = true; - TypeId err = singletonTypes.errorRecoveryType(singletonTypes.anyType); + TypeId err = getSingletonTypes().errorRecoveryType(getSingletonTypes().anyType); TypeId cloned = dest.addType(*err); seenTypes[typeId] = cloned; } @@ -421,9 +420,6 @@ TypePackId clone(TypePackId tp, TypeArena& dest, SeenTypes& seenTypes, SeenTypeP Luau::visit(cloner, tp->ty); // Mutates the storage that 'res' points into. } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -440,12 +436,11 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks { TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. + + // TODO: Make this work when the arena of 'res' might be frozen asMutable(res)->documentationSymbol = typeId->documentationSymbol; } - if (FFlag::DebugLuauTrackOwningArena) - asMutable(res)->owningArena = &dest; - return res; } @@ -508,8 +503,8 @@ bool Module::clonePublicInterface() if (moduleScope->varargPack) moduleScope->varargPack = clone(*moduleScope->varargPack, interfaceTypes, seenTypes, seenTypePacks, cloneState); - for (auto& pair : moduleScope->exportedTypeBindings) - pair.second = clone(pair.second, interfaceTypes, seenTypes, seenTypePacks, cloneState); + for (auto& [name, tf] : moduleScope->exportedTypeBindings) + tf = clone(tf, interfaceTypes, seenTypes, seenTypePacks, cloneState); for (TypeId ty : moduleScope->returnType) if (get(follow(ty))) diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/Predicate.cpp index 848627cf8..7bd8001e3 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/Predicate.cpp @@ -24,7 +24,7 @@ std::optional tryGetLValue(const AstExpr& node) else if (auto indexexpr = expr->as()) { if (auto lvalue = tryGetLValue(*indexexpr->expr)) - if (auto string = indexexpr->expr->as()) + if (auto string = indexexpr->index->as()) return Field{std::make_shared(*lvalue), std::string(string->value.data, string->value.size)}; } diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 6322096c4..a6be53482 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -13,6 +13,13 @@ LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) +/* + * Prefix generic typenames with gen- + * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 + * Fair warning: Setting this will break a lot of Luau unit tests. + */ +LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) + namespace Luau { @@ -290,7 +297,15 @@ struct TypeVarStringifier void operator()(TypeId ty, const Unifiable::Free& ftv) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(ty)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(ftv.level.level)); + } } void operator()(TypeId, const BoundTypeVar& btv) @@ -802,6 +817,8 @@ struct TypePackStringifier void operator()(TypePackId tp, const GenericTypePack& pack) { + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("gen-"); if (pack.explicitName) { state.result.nameMap.typePacks[tp] = pack.name; @@ -817,7 +834,16 @@ struct TypePackStringifier void operator()(TypePackId tp, const FreeTypePack& pack) { state.result.invalid = true; + if (FFlag::DebugLuauVerboseTypeNames) + state.emit("free-"); state.emit(state.getName(tp)); + + if (FFlag::DebugLuauVerboseTypeNames) + { + state.emit("-"); + state.emit(std::to_string(pack.level.level)); + } + state.emit("..."); } @@ -1181,20 +1207,24 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } -void dump(TypeId ty) +std::string dump(TypeId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } -void dump(TypePackId ty) +std::string dump(TypePackId ty) { ToStringOptions opts; opts.exhaustive = true; opts.functionTypeArguments = true; - printf("%s\n", toString(ty, opts).c_str()); + std::string s = toString(ty, opts); + printf("%s\n", s.c_str()); + return s; } std::string generateName(size_t i) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 617bf482c..abbc2901b 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -9,9 +9,9 @@ #include "Luau/Scope.h" #include "Luau/Substitution.h" #include "Luau/TopoSortStatements.h" -#include "Luau/ToString.h" #include "Luau/TypePack.h" #include "Luau/TypeUtils.h" +#include "Luau/ToString.h" #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" @@ -29,7 +29,6 @@ LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) -LUAU_FASTFLAGVARIABLE(LuauStrictRequire, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) @@ -37,6 +36,12 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) +LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) +LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) +LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) +LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau { @@ -206,14 +211,14 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan : resolver(resolver) , iceHandler(iceHandler) , unifierState(iceHandler) - , nilType(singletonTypes.nilType) - , numberType(singletonTypes.numberType) - , stringType(singletonTypes.stringType) - , booleanType(singletonTypes.booleanType) - , threadType(singletonTypes.threadType) - , anyType(singletonTypes.anyType) - , optionalNumberType(singletonTypes.optionalNumberType) - , anyTypePack(singletonTypes.anyTypePack) + , nilType(getSingletonTypes().nilType) + , numberType(getSingletonTypes().numberType) + , stringType(getSingletonTypes().stringType) + , booleanType(getSingletonTypes().booleanType) + , threadType(getSingletonTypes().threadType) + , anyType(getSingletonTypes().anyType) + , optionalNumberType(getSingletonTypes().optionalNumberType) + , anyTypePack(getSingletonTypes().anyTypePack) { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -443,7 +448,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) functionDecls[*protoIter] = pair; ++subLevel; - TypeId leftType = checkFunctionName(scope, *fun->name); + TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); unify(leftType, funTy, fun->location); } else if (auto fun = (*protoIter)->as()) @@ -711,14 +716,15 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) } else if (auto tail = valueIter.tail()) { - if (get(*tail)) + TypePackId tailPack = follow(*tail); + if (get(tailPack)) right = errorRecoveryType(scope); - else if (auto vtp = get(*tail)) + else if (auto vtp = get(tailPack)) right = vtp->ty; - else if (get(*tail)) + else if (get(tailPack)) { - *asMutable(*tail) = TypePack{{left}}; - growingPack = getMutable(*tail); + *asMutable(tailPack) = TypePack{{left}}; + growingPack = getMutable(tailPack); } } @@ -1107,8 +1113,27 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco unify(leftType, ty, function.location); - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + if (FFlag::LuauUpdateFunctionNameBinding) + { + LUAU_ASSERT(function.name->is() || function.name->is()); + + if (auto exprIndexName = function.name->as()) + { + if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) + { + if (auto ttv = getMutableTableType(*typeIt)) + { + if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) + it->second.type = follow(quantify(funScope, leftType, function.name->location)); + } + } + } + } + else + { + if (leftTypeBinding) + *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); + } } } @@ -1148,8 +1173,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { - ScopePtr aliasScope = - FFlag::LuauQuantifyInPlace2 ? childScope(scope, typealias.location, subLevel) : childScope(scope, typealias.location); + ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); + if (FFlag::LuauProperTypeLevels) + aliasScope->level.subLevel = subLevel; auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); @@ -1166,6 +1193,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ice("Not predeclared"); ScopePtr aliasScope = childScope(scope, typealias.location); + aliasScope->level = scope->level.incr(); for (TypeId ty : binding->typeParams) { @@ -1505,9 +1533,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - ice("Unexpected abstract type pack!"); + ice("Unexpected abstract type pack!", expr.location); else - ice("Unknown TypePack type!"); + ice("Unknown TypePack type!", expr.location); } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexName& expr) @@ -1574,7 +1602,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = freshType(scope); + TypeId result = FFlag::LuauAscribeCorrectLevelToInferredProperitesOfFreeTables ? freshType(tableType->level) : freshType(scope); tableType->props[name] = {result}; return result; } @@ -1738,7 +1766,16 @@ TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIndexExpr& expr) { - return {checkLValue(scope, expr)}; + TypeId ty = checkLValue(scope, expr); + + if (FFlag::LuauRefiLookupFromIndexExpr) + { + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional refiTy = resolveLValue(scope, *lvalue)) + return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; + } + + return {ty}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprFunction& expr, std::optional expectedType) @@ -2421,12 +2458,27 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy TypeId annotationType = resolveType(scope, *expr.annotation); ExprResult result = checkExpr(scope, *expr.expr, annotationType); - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); - reportErrors(errorVec); - if (!errorVec.empty()) - annotationType = errorRecoveryType(annotationType); + if (FFlag::LuauBidirectionalAsExpr) + { + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (canUnify(result.type, annotationType, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; + + if (canUnify(annotationType, result.type, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; - return {annotationType, std::move(result.predicates)}; + reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); + return {errorRecoveryType(annotationType), std::move(result.predicates)}; + } + else + { + ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + reportErrors(errorVec); + if (!errorVec.empty()) + annotationType = errorRecoveryType(annotationType); + + return {annotationType, std::move(result.predicates)}; + } } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) @@ -2674,8 +2726,15 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope // Answers the question: "Can I define another function with this name?" // Primarily about detecting duplicates. -TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) +TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { + auto freshTy = [&]() { + if (FFlag::LuauProperTypeLevels) + return freshType(level); + else + return freshType(scope); + }; + if (auto globalName = funName.as()) { const ScopePtr& globalScope = currentModule->getModuleScope(); @@ -2689,7 +2748,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) } else { - TypeId ty = freshType(scope); + TypeId ty = freshTy(); globalScope->bindings[name] = {ty, funName.location}; return ty; } @@ -2699,7 +2758,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Symbol name = localName->local; Binding& binding = scope->bindings[name]; if (binding.typeId == nullptr) - binding = {freshType(scope), funName.location}; + binding = {freshTy(), funName.location}; return binding.typeId; } @@ -2730,7 +2789,7 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName) Property& property = ttv->props[name]; - property.type = freshType(scope); + property.type = freshTy(); property.location = indexName->indexLocation; ttv->methodDefinitionLocations[name] = funName.location; return property.type; @@ -3327,7 +3386,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A fn = follow(fn); if (auto ret = checkCallOverload( - scope, expr, fn, retPack, argPack, args, argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) + scope, expr, fn, retPack, argPack, args, &argLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors)) return *ret; } @@ -3402,9 +3461,11 @@ std::vector> TypeChecker::getExpectedTypesForCall(const st } std::optional> TypeChecker::checkCallOverload(const ScopePtr& scope, const AstExprCall& expr, TypeId fn, TypePackId retPack, - TypePackId argPack, TypePack* args, const std::vector& argLocations, const ExprResult& argListResult, + TypePackId argPack, TypePack* args, const std::vector* argLocations, const ExprResult& argListResult, std::vector& overloadsThatMatchArgCount, std::vector& overloadsThatDont, std::vector& errors) { + LUAU_ASSERT(argLocations); + fn = stripFromNilAndReport(fn, expr.func->location); if (get(fn)) @@ -3428,31 +3489,44 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {{retPack}}; } - const FunctionTypeVar* ftv = get(fn); - if (!ftv) + std::vector metaArgLocations; + + // Might be a callable table + if (const MetatableTypeVar* mttv = get(fn)) { - // Might be a callable table - if (const MetatableTypeVar* mttv = get(fn)) + if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) { - if (std::optional ty = getIndexTypeFromType(scope, mttv->metatable, "__call", expr.func->location, false)) - { - // Construct arguments with 'self' added in front - TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); + // Construct arguments with 'self' added in front + TypePackId metaCallArgPack = addTypePack(TypePackVar(TypePack{args->head, args->tail})); - TypePack* metaCallArgs = getMutable(metaCallArgPack); - metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); + TypePack* metaCallArgs = getMutable(metaCallArgPack); + metaCallArgs->head.insert(metaCallArgs->head.begin(), fn); - std::vector metaArgLocations = argLocations; - metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + metaArgLocations = *argLocations; + metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); + if (FFlag::LuauFixRecursiveMetatableCall) + { + fn = instantiate(scope, *ty, expr.func->location); + + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; + } + else + { TypeId fn = *ty; fn = instantiate(scope, fn, expr.func->location); - return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, metaArgLocations, argListResult, + return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult, overloadsThatMatchArgCount, overloadsThatDont, errors); } } + } + const FunctionTypeVar* ftv = get(fn); + if (!ftv) + { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); unify(retPack, errorRecoveryTypePack(scope), expr.func->location); return {{errorRecoveryTypePack(retPack)}}; @@ -3477,7 +3551,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope return {}; } - checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); + checkArgumentList(scope, state, argPack, ftv->argTypes, *argLocations); if (!state.errors.empty()) { @@ -3772,7 +3846,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module if (moduleInfo.name.empty()) { - if (FFlag::LuauStrictRequire && currentModule->mode == Mode::Strict) + if (currentModule->mode == Mode::Strict) { reportError(TypeError{location, UnknownRequire{}}); return errorRecoveryType(anyType); @@ -4268,9 +4342,11 @@ ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& } // Creates a new Scope and carries forward the varargs from the parent. -ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location, int subLevel) +ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location) { - ScopePtr scope = std::make_shared(parent, subLevel); + ScopePtr scope = std::make_shared(parent); + if (FFlag::LuauProperTypeLevels) + scope->level = parent->level; scope->varargPack = parent->varargPack; currentModule->scopes.push_back(std::make_pair(location, scope)); @@ -4329,22 +4405,22 @@ TypeId TypeChecker::singletonType(std::string value) TypeId TypeChecker::errorRecoveryType(const ScopePtr& scope) { - return singletonTypes.errorRecoveryType(); + return getSingletonTypes().errorRecoveryType(); } TypeId TypeChecker::errorRecoveryType(TypeId guess) { - return singletonTypes.errorRecoveryType(guess); + return getSingletonTypes().errorRecoveryType(guess); } TypePackId TypeChecker::errorRecoveryTypePack(const ScopePtr& scope) { - return singletonTypes.errorRecoveryTypePack(); + return getSingletonTypes().errorRecoveryTypePack(); } TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) { - return singletonTypes.errorRecoveryTypePack(guess); + return getSingletonTypes().errorRecoveryTypePack(guess); } std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) @@ -4547,6 +4623,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation else if (const auto& func = annotation.as()) { ScopePtr funcScope = childScope(scope, func->location); + funcScope->level = scope->level.incr(); auto [generics, genericPacks] = createGenericTypes(funcScope, std::nullopt, annotation, func->generics, func->genericPacks); diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 0d9d91e0c..8c6d5e49f 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -19,7 +19,7 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa TypeId unwrapped = follow(*metatable); if (get(unwrapped)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; const TableTypeVar* mtt = getTableType(unwrapped); if (!mtt) @@ -61,12 +61,12 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc { std::optional r = first(follow(itf->retType)); if (!r) - return singletonTypes.nilType; + return getSingletonTypes().nilType; else return *r; } else if (get(index)) - return singletonTypes.anyType; + return getSingletonTypes().anyType; else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 62715af53..571b13ca0 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -21,6 +21,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -579,11 +580,25 @@ SingletonTypes::SingletonTypes() , arena(new TypeArena) { TypeId stringMetatable = makeStringMetatable(); - stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, makeStringMetatable()}; + stringType_.ty = PrimitiveTypeVar{PrimitiveTypeVar::String, stringMetatable}; persist(stringMetatable); + + debugFreezeArena = FFlag::DebugLuauFreezeArena; freeze(*arena); } +SingletonTypes::~SingletonTypes() +{ + // Destroy the arena with the same memory management flags it was created with + bool prevFlag = FFlag::DebugLuauFreezeArena; + FFlag::DebugLuauFreezeArena.value = debugFreezeArena; + + unfreeze(*arena); + arena.reset(nullptr); + + FFlag::DebugLuauFreezeArena.value = prevFlag; +} + TypeId SingletonTypes::makeStringMetatable() { const TypeId optionalNumber = arena->addType(UnionTypeVar{{nilType, numberType}}); @@ -641,6 +656,9 @@ TypeId SingletonTypes::makeStringMetatable() TypeId tableType = arena->addType(TableTypeVar{std::move(stringLib), std::nullopt, TypeLevel{}, TableState::Sealed}); + if (TableTypeVar* ttv = getMutable(tableType)) + ttv->name = "string"; + return arena->addType(TableTypeVar{{{{"__index", {tableType}}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); } @@ -670,7 +688,11 @@ TypePackId SingletonTypes::errorRecoveryTypePack(TypePackId guess) return &errorTypePack_; } -SingletonTypes singletonTypes; +SingletonTypes& getSingletonTypes() +{ + static SingletonTypes singletonTypes; + return singletonTypes; +} void persist(TypeId ty) { @@ -719,6 +741,18 @@ void persist(TypeId ty) for (TypeId opt : itv->parts) queue.push_back(opt); } + else if (auto mtv = get(t)) + { + queue.push_back(mtv->table); + queue.push_back(mtv->metatable); + } + else if (get(t) || get(t) || get(t) || get(t) || get(t)) + { + } + else + { + LUAU_ASSERT(!"TypeId is not supported in a persist call"); + } } } @@ -736,6 +770,17 @@ void persist(TypePackId tp) if (p->tail) persist(*p->tail); } + else if (auto vtp = get(tp)) + { + persist(vtp->ty); + } + else if (get(tp)) + { + } + else + { + LUAU_ASSERT(!"TypePackId is not supported in a persist call"); + } } const TypeLevel* getLevel(TypeId ty) @@ -757,167 +802,6 @@ TypeLevel* getMutableLevel(TypeId ty) return const_cast(getLevel(ty)); } -struct QVarFinder -{ - mutable DenseHashSet seen; - - QVarFinder() - : seen(nullptr) - { - } - - bool hasSeen(const void* tv) const - { - if (seen.contains(tv)) - return true; - - seen.insert(tv); - return false; - } - - bool hasGeneric(TypeId tid) const - { - if (hasSeen(&tid->ty)) - return false; - - return Luau::visit(*this, tid->ty); - } - - bool hasGeneric(TypePackId tp) const - { - if (hasSeen(&tp->ty)) - return false; - - return Luau::visit(*this, tp->ty); - } - - bool operator()(const Unifiable::Free&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const Unifiable::Generic&) const - { - return true; - } - bool operator()(const Unifiable::Error&) const - { - return false; - } - bool operator()(const PrimitiveTypeVar&) const - { - return false; - } - - bool operator()(const SingletonTypeVar&) const - { - return false; - } - - bool operator()(const FunctionTypeVar& ftv) const - { - if (hasGeneric(ftv.argTypes)) - return true; - return hasGeneric(ftv.retType); - } - - bool operator()(const TableTypeVar& ttv) const - { - if (ttv.state == TableState::Generic) - return true; - - if (ttv.indexer) - { - if (hasGeneric(ttv.indexer->indexType)) - return true; - if (hasGeneric(ttv.indexer->indexResultType)) - return true; - } - - for (const auto& [_name, prop] : ttv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - return false; - } - - bool operator()(const MetatableTypeVar& mtv) const - { - return hasGeneric(mtv.table) || hasGeneric(mtv.metatable); - } - - bool operator()(const ClassTypeVar& ctv) const - { - for (const auto& [name, prop] : ctv.props) - { - if (hasGeneric(prop.type)) - return true; - } - - if (ctv.parent) - return hasGeneric(*ctv.parent); - - return false; - } - - bool operator()(const AnyTypeVar&) const - { - return false; - } - - bool operator()(const UnionTypeVar& utv) const - { - for (TypeId tid : utv.options) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const IntersectionTypeVar& utv) const - { - for (TypeId tid : utv.parts) - if (hasGeneric(tid)) - return true; - - return false; - } - - bool operator()(const LazyTypeVar&) const - { - return false; - } - - bool operator()(const Unifiable::Bound& bound) const - { - return hasGeneric(bound.boundTo); - } - - bool operator()(const TypePack& pack) const - { - for (TypeId ty : pack.head) - if (hasGeneric(ty)) - return true; - - if (pack.tail) - return hasGeneric(*pack.tail); - - return false; - } - - bool operator()(const VariadicTypePack& pack) const - { - return hasGeneric(pack.ty); - } -}; - const Property* lookupClassProp(const ClassTypeVar* cls, const Name& name) { while (cls) @@ -953,16 +837,6 @@ bool isSubclass(const ClassTypeVar* cls, const ClassTypeVar* parent) return false; } -bool hasGeneric(TypeId ty) -{ - return Luau::visit(QVarFinder{}, ty->ty); -} - -bool hasGeneric(TypePackId tp) -{ - return Luau::visit(QVarFinder{}, tp->ty); -} - UnionTypeVarIterator::UnionTypeVarIterator(const UnionTypeVar* utv) { LUAU_ASSERT(utv); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d0b188374..c5aab8562 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,7 +14,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); -LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance, false); +LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) @@ -22,9 +22,82 @@ LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) +LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) namespace Luau { + +struct PromoteTypeLevels +{ + TxnLog& log; + TypeLevel minLevel; + + explicit PromoteTypeLevels(TxnLog& log, TypeLevel minLevel) + : log(log) + , minLevel(minLevel) + {} + + template + void promote(TID ty, T* t) + { + LUAU_ASSERT(t); + if (minLevel.subsumesStrict(t->level)) + { + log(ty); + t->level = minLevel; + } + } + + template + void cycle(TID) {} + + template + bool operator()(TID, const T&) + { + return true; + } + + bool operator()(TypeId ty, const FreeTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const FunctionTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypeId ty, const TableTypeVar&) + { + promote(ty, getMutable(ty)); + return true; + } + + bool operator()(TypePackId tp, const FreeTypePack&) + { + promote(tp, getMutable(tp)); + return true; + } +}; + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypeId ty) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, ptl, seen); +} + +void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypePackId tp) +{ + PromoteTypeLevels ptl{log, minLevel}; + DenseHashSet seen{nullptr}; + visitTypeVarOnce(tp, ptl, seen); +} + struct SkipCacheForType { SkipCacheForType(const DenseHashMap& skipCacheForType) @@ -127,6 +200,29 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) return *it; } +// Used for tagged union matching heuristic, returns first singleton type field +static std::optional> getTableMatchTag(TypeId type) +{ + LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError); + + type = follow(type); + + if (auto ttv = get(type)) + { + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } + } + else if (auto mttv = get(type)) + { + return getTableMatchTag(mttv->table); + } + + return std::nullopt; +} + Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) : types(types) , mode(mode) @@ -214,9 +310,11 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { occursCheck(superTy, subTy); + TypeLevel superLevel = l->level; + // Unification can't change the level of a generic. auto rightGeneric = get(subTy); - if (rightGeneric && !rightGeneric->level.subsumes(l->level)) + if (rightGeneric && !rightGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -226,7 +324,9 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // The occurrence check might have caused superTy no longer to be a free type if (!get(superTy)) { - if (auto rightLevel = getMutableLevel(subTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, superLevel, subTy); + else if (auto rightLevel = getMutableLevel(subTy)) { if (!rightLevel->subsumes(l->level)) *rightLevel = l->level; @@ -240,6 +340,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } else if (r) { + TypeLevel subLevel = r->level; + occursCheck(subTy, superTy); // Unification can't change the level of a generic. @@ -253,10 +355,16 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (!get(subTy)) { - if (auto leftLevel = getMutableLevel(superTy)) + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(log, subLevel, superTy); + + if (auto superLevel = getMutableLevel(superTy)) { - if (!leftLevel->subsumes(r->level)) - *leftLevel = r->level; + if (!superLevel->subsumes(r->level)) + { + log(superTy); + *superLevel = r->level; + } } log(subTy); @@ -327,7 +435,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (failed) { if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } @@ -338,28 +446,46 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool bool found = false; std::optional unificationTooComplex; + size_t failedOptionCount = 0; + std::optional failedOption; + + bool foundHeuristic = false; size_t startIndex = 0; if (FFlag::LuauUnionHeuristic) { - bool found = false; - - const std::string* subName = getName(subTy); - if (subName) + if (const std::string* subName = getName(subTy)) { for (size_t i = 0; i < uv->options.size(); ++i) { const std::string* optionName = getName(uv->options[i]); if (optionName && *optionName == *subName) { - found = true; + foundHeuristic = true; startIndex = i; break; } } } - if (!found && cacheEnabled) + if (FFlag::LuauExtendedUnionMismatchError) + { + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) + { + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) + { + foundHeuristic = true; + startIndex = i; + break; + } + } + } + } + + if (!foundHeuristic && cacheEnabled) { for (size_t i = 0; i < uv->options.size(); ++i) { @@ -390,15 +516,27 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { unificationTooComplex = e; } + else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type)) + { + failedOptionCount++; + + if (!failedOption) + failedOption = {innerState.errors.front()}; + } innerState.log.rollback(); } if (unificationTooComplex) + { errors.push_back(*unificationTooComplex); + } else if (!found) { - if (FFlag::LuauExtendedTypeMismatchError) + if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + else if (FFlag::LuauExtendedTypeMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -431,7 +569,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (unificationTooComplex) errors.push_back(*unificationTooComplex); else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible", *firstFailedOption}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else { @@ -771,6 +909,10 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal if (superIter.good() && subIter.good()) { tryUnify_(*superIter, *subIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + superIter.advance(); subIter.advance(); continue; @@ -853,13 +995,13 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal while (superIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *superIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *superIter); superIter.advance(); } while (subIter.good()) { - tryUnify_(singletonTypes.errorRecoveryType(), *subIter); + tryUnify_(getSingletonTypes().errorRecoveryType(), *subIter); subIter.advance(); } @@ -917,14 +1059,22 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal if (numGenerics != rf->generics.size()) { numGenerics = std::min(lf->generics.size(), rf->generics.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } size_t numGenericPacks = lf->genericPacks.size(); if (numGenericPacks != rf->genericPacks.size()) { numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauExtendedFunctionMismatchError) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); + else + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } for (size_t i = 0; i < numGenerics; i++) @@ -936,13 +1086,49 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + if (FFlag::LuauExtendedFunctionMismatchError) + { + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); - ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + bool reported = !innerState.errors.empty(); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + + innerState.ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + if (!reported) + { + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty() && size(lf->retType) == 1 && finite(lf->retType)) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + errors.push_back( + TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + } + } + else + { + ctx = CountMismatch::Arg; + innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + ctx = CountMismatch::Result; + innerState.tryUnify_(lf->retType, rf->retType); + + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + } log.concat(std::move(innerState.log)); } @@ -994,7 +1180,7 @@ struct Resetter void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - if (!FFlag::LuauTableSubtypingVariance) + if (!FFlag::LuauTableSubtypingVariance2) return DEPRECATED_tryUnifyTables(left, right, isIntersection); TableTypeVar* lt = getMutable(left); @@ -1133,7 +1319,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; clone.type = deeplyOptional(clone.type); - log(lt); + log(left); lt->props[name] = clone; } else if (variance == Covariant) @@ -1146,7 +1332,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) } else if (lt->state == TableState::Free) { - log(lt); + log(left); lt->props[name] = prop; } else @@ -1176,7 +1362,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - log(rt); + log(right); rt->indexer = lt->indexer; } } @@ -1185,7 +1371,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // Symmetric if we are invariant if (lt->state == TableState::Unsealed || lt->state == TableState::Free) { - log(lt); + log(left); lt->indexer = rt->indexer; } } @@ -1241,15 +1427,15 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see TableTypeVar* resultTtv = getMutable(result); for (auto& [name, prop] : resultTtv->props) prop.type = deeplyOptional(prop.type, seen); - return types->addType(UnionTypeVar{{singletonTypes.nilType, result}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, result}}); } else - return types->addType(UnionTypeVar{{singletonTypes.nilType, ty}}); + return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) { - LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance); + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); Resetter resetter{&variance}; variance = Invariant; @@ -1467,7 +1653,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio } else if (lt->indexer) { - innerState.tryUnify_(lt->indexer->indexType, singletonTypes.stringType); + innerState.tryUnify_(lt->indexer->indexType, getSingletonTypes().stringType); // We already try to unify properties in both tables. // Skip those and just look for the ones remaining and see if they fit into the indexer. for (const auto& [name, type] : rt->props) @@ -1636,7 +1822,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, singletonTypes.errorRecoveryType()); + tryUnify_(prop.type, getSingletonTypes().errorRecoveryType()); } else { @@ -1825,7 +2011,7 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) if (get(ty) || get(ty) || get(ty)) return; - const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{singletonTypes.anyType}}); + const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); @@ -1834,14 +2020,14 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, singletonTypes.anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) { LUAU_ASSERT(get(any)); - const TypeId anyTy = singletonTypes.errorRecoveryType(); + const TypeId anyTy = getSingletonTypes().errorRecoveryType(); std::vector queue; @@ -1887,7 +2073,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryType(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); return; } @@ -1951,7 +2137,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { errors.push_back(TypeError{location, OccursCheckFailed{}}); log(needle); - *asMutable(needle) = *singletonTypes.errorRecoveryTypePack(); + *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); return; } @@ -2005,7 +2191,7 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const s errors.push_back(*e); else if (!innerErrors.empty()) errors.push_back( - TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible", prop.c_str()), innerErrors.front()}}); + TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); } void Unifier::ice(const std::string& message, const Location& location) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3d0d5b7e6..dd24f27cb 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauCaptureBrokenCommentSpans, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) @@ -159,7 +158,7 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { std::vector hotcomments; - while (isComment(p.lexer.current()) || (FFlag::LuauCaptureBrokenCommentSpans && p.lexer.current().type == Lexeme::BrokenComment)) + while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) { const char* text = p.lexer.current().data; unsigned int length = p.lexer.current().length; @@ -2780,7 +2779,7 @@ const Lexeme& Parser::nextLexeme() const Lexeme& lexeme = lexer.next(/*skipComments*/ false); // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. // The parser will turn this into a proper syntax error. - if (FFlag::LuauCaptureBrokenCommentSpans && lexeme.type == Lexeme::BrokenComment) + if (lexeme.type == Lexeme::BrokenComment) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); if (isComment(lexeme)) commentLocations.push_back(Comment{lexeme.type, lexeme.location}); diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 9230d80d0..aecb619a1 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -11,26 +11,33 @@ enum class ReportFormat { Default, - Luacheck + Luacheck, + Gnu, }; -static void report(ReportFormat format, const char* name, const Luau::Location& location, const char* type, const char* message) +static void report(ReportFormat format, const char* name, const Luau::Location& loc, const char* type, const char* message) { switch (format) { case ReportFormat::Default: - fprintf(stderr, "%s(%d,%d): %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, type, message); + fprintf(stderr, "%s(%d,%d): %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, type, message); break; case ReportFormat::Luacheck: { // Note: luacheck's end column is inclusive but our end column is exclusive // In addition, luacheck doesn't support multi-line messages, so if the error is multiline we'll fake end column as 100 and hope for the best - int columnEnd = (location.begin.line == location.end.line) ? location.end.column : 100; + int columnEnd = (loc.begin.line == loc.end.line) ? loc.end.column : 100; - fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, location.begin.line + 1, location.begin.column + 1, columnEnd, type, message); + // Use stdout to match luacheck behavior + fprintf(stdout, "%s:%d:%d-%d: (W0) %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, columnEnd, type, message); break; } + + case ReportFormat::Gnu: + // Note: GNU end column is inclusive but our end column is exclusive + fprintf(stderr, "%s:%d.%d-%d.%d: %s: %s\n", name, loc.begin.line + 1, loc.begin.column + 1, loc.end.line + 1, loc.end.column, type, message); + break; } } @@ -97,6 +104,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); + printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -201,6 +209,8 @@ int main(int argc, char** argv) if (strcmp(argv[i], "--formatter=plain") == 0) format = ReportFormat::Luacheck; + else if (strcmp(argv[i], "--formatter=gnu") == 0) + format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; } diff --git a/CLI/Coverage.cpp b/CLI/Coverage.cpp new file mode 100644 index 000000000..254df3f03 --- /dev/null +++ b/CLI/Coverage.cpp @@ -0,0 +1,88 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Coverage.h" + +#include "lua.h" + +#include +#include + +struct Coverage +{ + lua_State* L = nullptr; + std::vector functions; +} gCoverage; + +void coverageInit(lua_State* L) +{ + gCoverage.L = lua_mainthread(L); +} + +bool coverageActive() +{ + return gCoverage.L != nullptr; +} + +void coverageTrack(lua_State* L, int funcindex) +{ + int ref = lua_ref(L, funcindex); + gCoverage.functions.push_back(ref); +} + +static void coverageCallback(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) +{ + FILE* f = static_cast(context); + + std::string name; + + if (depth == 0) + name = "
"; + else if (function) + name = std::string(function) + ":" + std::to_string(linedefined); + else + name = ":" + std::to_string(linedefined); + + fprintf(f, "FN:%d,%s\n", linedefined, name.c_str()); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + fprintf(f, "FNDA:%d,%s\n", hits[i], name.c_str()); + break; + } + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + fprintf(f, "DA:%d,%d\n", int(i), hits[i]); +} + +void coverageDump(const char* path) +{ + lua_State* L = gCoverage.L; + + FILE* f = fopen(path, "w"); + if (!f) + { + fprintf(stderr, "Error opening coverage %s\n", path); + return; + } + + fprintf(f, "TN:\n"); + + for (int fref: gCoverage.functions) + { + lua_getref(L, fref); + + lua_Debug ar = {}; + lua_getinfo(L, -1, "s", &ar); + + fprintf(f, "SF:%s\n", ar.short_src); + lua_getcoverage(L, -1, f, coverageCallback); + fprintf(f, "end_of_record\n"); + + lua_pop(L, 1); + } + + fclose(f); + + printf("Coverage dump written to %s (%d functions)\n", path, int(gCoverage.functions.size())); +} diff --git a/CLI/Coverage.h b/CLI/Coverage.h new file mode 100644 index 000000000..74be4e5c9 --- /dev/null +++ b/CLI/Coverage.h @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +struct lua_State; + +void coverageInit(lua_State* L); +bool coverageActive(); + +void coverageTrack(lua_State* L, int funcindex); +void coverageDump(const char* path); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index b3c9557bb..cb993dfee 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -67,6 +67,10 @@ std::optional readFile(const std::string& name) if (read != size_t(length)) return std::nullopt; + // Skip first line if it's a shebang + if (length > 2 && result[0] == '#' && result[1] == '!') + result.erase(0, result.find('\n')); + return result; } diff --git a/CLI/Profiler.cpp b/CLI/Profiler.cpp index c6d15a7f2..30a171f0f 100644 --- a/CLI/Profiler.cpp +++ b/CLI/Profiler.cpp @@ -110,12 +110,12 @@ void profilerStop() gProfiler.thread.join(); } -void profilerDump(const char* name) +void profilerDump(const char* path) { - FILE* f = fopen(name, "wb"); + FILE* f = fopen(path, "wb"); if (!f) { - fprintf(stderr, "Error opening profile %s\n", name); + fprintf(stderr, "Error opening profile %s\n", path); return; } @@ -129,7 +129,7 @@ void profilerDump(const char* name) fclose(f); - printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", name, double(total) / 1e6, + printf("Profiler dump written to %s (total runtime %.3f seconds, %lld samples, %lld stacks)\n", path, double(total) / 1e6, static_cast(gProfiler.samples.load()), static_cast(gProfiler.data.size())); uint64_t totalgc = 0; diff --git a/CLI/Profiler.h b/CLI/Profiler.h index 0a407e476..67b1acfd0 100644 --- a/CLI/Profiler.h +++ b/CLI/Profiler.h @@ -5,4 +5,4 @@ struct lua_State; void profilerStart(lua_State* L, int frequency); void profilerStop(); -void profilerDump(const char* name); \ No newline at end of file +void profilerDump(const char* path); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 2cdd0062f..35c02f2c8 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -8,6 +8,7 @@ #include "FileUtils.h" #include "Profiler.h" +#include "Coverage.h" #include "linenoise.hpp" @@ -24,6 +25,16 @@ enum class CompileFormat Binary }; +static Luau::CompileOptions copts() +{ + Luau::CompileOptions result = {}; + result.optimizationLevel = 1; + result.debugLevel = 1; + result.coverageLevel = coverageActive() ? 2 : 0; + + return result; +} + static int lua_loadstring(lua_State* L) { size_t l = 0; @@ -32,7 +43,7 @@ static int lua_loadstring(lua_State* L) lua_setsafeenv(L, LUA_ENVIRONINDEX, false); - std::string bytecode = Luau::compile(std::string(s, l)); + std::string bytecode = Luau::compile(std::string(s, l), copts()); if (luau_load(L, chunkname, bytecode.data(), bytecode.size(), 0) == 0) return 1; @@ -79,9 +90,12 @@ static int lua_require(lua_State* L) luaL_sandboxthread(ML); // now we can compile & run module on the new thread - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); if (luau_load(ML, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(ML, -1); + int status = lua_resume(ML, L, 0); if (status == 0) @@ -149,7 +163,7 @@ static void setupState(lua_State* L) static std::string runCode(lua_State* L, const std::string& source) { - std::string bytecode = Luau::compile(source); + std::string bytecode = Luau::compile(source, copts()); if (luau_load(L, "=stdin", bytecode.data(), bytecode.size(), 0) != 0) { @@ -329,11 +343,14 @@ static bool runFile(const char* name, lua_State* GL) std::string chunkname = "=" + std::string(name); - std::string bytecode = Luau::compile(*source); + std::string bytecode = Luau::compile(*source, copts()); int status = 0; if (luau_load(L, chunkname.c_str(), bytecode.data(), bytecode.size(), 0) == 0) { + if (coverageActive()) + coverageTrack(L, -1); + status = lua_resume(L, NULL, 0); } else @@ -437,6 +454,7 @@ static void displayHelp(const char* argv0) printf("\n"); printf("Available options:\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); + printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); } static int assertionHandler(const char* expr, const char* file, int line) @@ -495,6 +513,7 @@ int main(int argc, char** argv) setupState(L); int profile = 0; + bool coverage = false; for (int i = 1; i < argc; ++i) { @@ -505,11 +524,16 @@ int main(int argc, char** argv) profile = 10000; // default to 10 KHz else if (strncmp(argv[i], "--profile=", 10) == 0) profile = atoi(argv[i] + 10); + else if (strcmp(argv[i], "--coverage") == 0) + coverage = true; } if (profile) profilerStart(L, profile); + if (coverage) + coverageInit(L); + std::vector files = getSourceFiles(argc, argv); int failed = 0; @@ -523,6 +547,9 @@ int main(int argc, char** argv) profilerDump("profile.out"); } + if (coverage) + coverageDump("coverage.out"); + return failed; } } diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 2c1e85ff0..8f74ffedd 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauPreloadClosures, false) LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) @@ -462,20 +461,17 @@ struct Compiler bool shared = false; - if (FFlag::LuauPreloadClosures) + // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure + // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it + // is used) + if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) { - // Optimization: when closure has no upvalues, or upvalues are safe to share, instead of allocating it every time we can share closure - // objects (this breaks assumptions about function identity which can lead to setfenv not working as expected, so we disable this when it - // is used) - if (options.optimizationLevel >= 1 && shouldShareClosure(expr) && !setfenvUsed) - { - int32_t cid = bytecode.addConstantClosure(f->id); + int32_t cid = bytecode.addConstantClosure(f->id); - if (cid >= 0 && cid < 32768) - { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); - shared = true; - } + if (cid >= 0 && cid < 32768) + { + bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + shared = true; } } diff --git a/Makefile b/Makefile index 15c7ff7a4..b144cac60 100644 --- a/Makefile +++ b/Makefile @@ -27,7 +27,7 @@ TESTS_SOURCES=$(wildcard tests/*.cpp) TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Repl.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau @@ -128,10 +128,10 @@ luau-size: luau # executable target aliases luau: $(REPL_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ luau-analyze: $(ANALYZE_CLI_TARGET) - cp $^ $@ + ln -fs $^ $@ # executable targets $(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) diff --git a/Sources.cmake b/Sources.cmake index 57df9b91e..14834b3a5 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -133,6 +133,7 @@ target_sources(Luau.VM PRIVATE VM/src/ltable.cpp VM/src/ltablib.cpp VM/src/ltm.cpp + VM/src/ludata.cpp VM/src/lutf8lib.cpp VM/src/lvmexecute.cpp VM/src/lvmload.cpp @@ -152,12 +153,15 @@ target_sources(Luau.VM PRIVATE VM/src/lstring.h VM/src/ltable.h VM/src/ltm.h + VM/src/ludata.h VM/src/lvm.h ) if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE + CLI/Coverage.h + CLI/Coverage.cpp CLI/FileUtils.h CLI/FileUtils.cpp CLI/Profiler.h diff --git a/VM/include/lua.h b/VM/include/lua.h index 7078acd0f..55902160c 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -334,6 +334,10 @@ LUA_API const char* lua_setupvalue(lua_State* L, int funcindex, int n); LUA_API void lua_singlestep(lua_State* L, int enabled); LUA_API void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled); +typedef void (*lua_Coverage)(void* context, const char* function, int linedefined, int depth, const int* hits, size_t size); + +LUA_API void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback); + /* Warning: this function is not thread-safe since it stores the result in a shared global array! Only use for debugging. */ LUA_API const char* lua_debugtrace(lua_State* L); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 76043b9cd..a65b03253 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -8,6 +8,7 @@ #include "lfunc.h" #include "lgc.h" #include "ldo.h" +#include "ludata.h" #include "lvm.h" #include "lnumutils.h" @@ -43,36 +44,30 @@ static Table* getcurrenv(lua_State* L) } } -static LUAU_NOINLINE TValue* index2adrslow(lua_State* L, int idx) +static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) { - api_check(L, idx <= 0); - if (idx > LUA_REGISTRYINDEX) + api_check(L, lua_ispseudo(idx)); + switch (idx) + { /* pseudo-indices */ + case LUA_REGISTRYINDEX: + return registry(L); + case LUA_ENVIRONINDEX: { - api_check(L, idx != 0 && -idx <= L->top - L->base); - return L->top + idx; + sethvalue(L, &L->env, getcurrenv(L)); + return &L->env; + } + case LUA_GLOBALSINDEX: + return gt(L); + default: + { + Closure* func = curr_func(L); + idx = LUA_GLOBALSINDEX - idx; + return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); + } } - else - switch (idx) - { /* pseudo-indices */ - case LUA_REGISTRYINDEX: - return registry(L); - case LUA_ENVIRONINDEX: - { - sethvalue(L, &L->env, getcurrenv(L)); - return &L->env; - } - case LUA_GLOBALSINDEX: - return gt(L); - default: - { - Closure* func = curr_func(L); - idx = LUA_GLOBALSINDEX - idx; - return (idx <= func->nupvalues) ? &func->c.upvals[idx - 1] : cast_to(TValue*, luaO_nilobject); - } - } } -static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) +static LUAU_FORCEINLINE TValue* index2addr(lua_State* L, int idx) { if (idx > 0) { @@ -83,15 +78,20 @@ static LUAU_FORCEINLINE TValue* index2adr(lua_State* L, int idx) else return o; } + else if (idx > LUA_REGISTRYINDEX) + { + api_check(L, idx != 0 && -idx <= L->top - L->base); + return L->top + idx; + } else { - return index2adrslow(L, idx); + return pseudo2addr(L, idx); } } const TValue* luaA_toobject(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); return (p == luaO_nilobject) ? NULL : p; } @@ -145,7 +145,7 @@ void lua_xpush(lua_State* from, lua_State* to, int idx) { api_check(from, from->global == to->global); luaC_checkthreadsleep(to); - setobj2s(to, to->top, index2adr(from, idx)); + setobj2s(to, to->top, index2addr(from, idx)); api_incr_top(to); return; } @@ -202,7 +202,7 @@ void lua_settop(lua_State* L, int idx) void lua_remove(lua_State* L, int idx) { - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); while (++p < L->top) setobjs2s(L, p - 1, p); @@ -213,7 +213,7 @@ void lua_remove(lua_State* L, int idx) void lua_insert(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); api_checkvalidindex(L, p); for (StkId q = L->top; q > p; q--) setobjs2s(L, q, q - 1); @@ -228,7 +228,7 @@ void lua_replace(lua_State* L, int idx) luaG_runerror(L, "no calling environment"); api_checknelems(L, 1); luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); api_checkvalidindex(L, o); if (idx == LUA_ENVIRONINDEX) { @@ -250,7 +250,7 @@ void lua_replace(lua_State* L, int idx) void lua_pushvalue(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); setobj2s(L, L->top, o); api_incr_top(L); return; @@ -262,7 +262,7 @@ void lua_pushvalue(lua_State* L, int idx) int lua_type(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (o == luaO_nilobject) ? LUA_TNONE : ttype(o); } @@ -273,20 +273,20 @@ const char* lua_typename(lua_State* L, int t) int lua_iscfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return iscfunction(o); } int lua_isLfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return isLfunction(o); } int lua_isnumber(lua_State* L, int idx) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return tonumber(o, &n); } @@ -298,14 +298,14 @@ int lua_isstring(lua_State* L, int idx) int lua_isuserdata(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return (ttisuserdata(o) || ttislightuserdata(o)); } int lua_rawequal(lua_State* L, int index1, int index2) { - StkId o1 = index2adr(L, index1); - StkId o2 = index2adr(L, index2); + StkId o1 = index2addr(L, index1); + StkId o2 = index2addr(L, index2); return (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaO_rawequalObj(o1, o2); } @@ -313,8 +313,8 @@ int lua_equal(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : equalobj(L, o1, o2); return i; } @@ -323,8 +323,8 @@ int lua_lessthan(lua_State* L, int index1, int index2) { StkId o1, o2; int i; - o1 = index2adr(L, index1); - o2 = index2adr(L, index2); + o1 = index2addr(L, index1); + o2 = index2addr(L, index2); i = (o1 == luaO_nilobject || o2 == luaO_nilobject) ? 0 : luaV_lessthan(L, o1, o2); return i; } @@ -332,7 +332,7 @@ int lua_lessthan(lua_State* L, int index1, int index2) double lua_tonumberx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { if (isnum) @@ -350,7 +350,7 @@ double lua_tonumberx(lua_State* L, int idx, int* isnum) int lua_tointegerx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { int res; @@ -371,7 +371,7 @@ int lua_tointegerx(lua_State* L, int idx, int* isnum) unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) { TValue n; - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); if (tonumber(o, &n)) { unsigned res; @@ -391,13 +391,13 @@ unsigned lua_tounsignedx(lua_State* L, int idx, int* isnum) int lua_toboolean(lua_State* L, int idx) { - const TValue* o = index2adr(L, idx); + const TValue* o = index2addr(L, idx); return !l_isfalse(o); } const char* lua_tolstring(lua_State* L, int idx, size_t* len) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) { luaC_checkthreadsleep(L); @@ -408,7 +408,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) return NULL; } luaC_checkGC(L); - o = index2adr(L, idx); /* previous call may reallocate the stack */ + o = index2addr(L, idx); /* previous call may reallocate the stack */ } if (len != NULL) *len = tsvalue(o)->len; @@ -417,7 +417,7 @@ const char* lua_tolstring(lua_State* L, int idx, size_t* len) const char* lua_tostringatom(lua_State* L, int idx, int* atom) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisstring(o)) return NULL; const TString* s = tsvalue(o); @@ -438,7 +438,7 @@ const char* lua_namecallatom(lua_State* L, int* atom) const float* lua_tovector(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (!ttisvector(o)) { return NULL; @@ -448,7 +448,7 @@ const float* lua_tovector(lua_State* L, int idx) int lua_objlen(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TSTRING: @@ -469,13 +469,13 @@ int lua_objlen(lua_State* L, int idx) lua_CFunction lua_tocfunction(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!iscfunction(o)) ? NULL : cast_to(lua_CFunction, clvalue(o)->c.f); } void* lua_touserdata(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TUSERDATA: @@ -489,13 +489,13 @@ void* lua_touserdata(lua_State* L, int idx) void* lua_touserdatatagged(lua_State* L, int idx, int tag) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (ttisuserdata(o) && uvalue(o)->tag == tag) ? uvalue(o)->data : NULL; } int lua_userdatatag(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); if (ttisuserdata(o)) return uvalue(o)->tag; return -1; @@ -503,13 +503,13 @@ int lua_userdatatag(lua_State* L, int idx) lua_State* lua_tothread(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); return (!ttisthread(o)) ? NULL : thvalue(o); } const void* lua_topointer(lua_State* L, int idx) { - StkId o = index2adr(L, idx); + StkId o = index2addr(L, idx); switch (ttype(o)) { case LUA_TTABLE: @@ -657,7 +657,7 @@ int lua_pushthread(lua_State* L) void lua_gettable(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_gettable(L, t, L->top - 1, L->top - 1); return; @@ -666,7 +666,7 @@ void lua_gettable(lua_State* L, int idx) void lua_getfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_checkvalidindex(L, t); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -678,7 +678,7 @@ void lua_getfield(lua_State* L, int idx, const char* k) void lua_rawgetfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); TValue key; setsvalue(L, &key, luaS_new(L, k)); @@ -690,7 +690,7 @@ void lua_rawgetfield(lua_State* L, int idx, const char* k) void lua_rawget(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); return; @@ -699,7 +699,7 @@ void lua_rawget(lua_State* L, int idx) void lua_rawgeti(lua_State* L, int idx, int n) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); api_incr_top(L); @@ -717,7 +717,7 @@ void lua_createtable(lua_State* L, int narray, int nrec) void lua_setreadonly(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); api_check(L, t != hvalue(registry(L))); @@ -727,7 +727,7 @@ void lua_setreadonly(lua_State* L, int objindex, int enabled) int lua_getreadonly(lua_State* L, int objindex) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); int res = t->readonly; @@ -736,7 +736,7 @@ int lua_getreadonly(lua_State* L, int objindex) void lua_setsafeenv(lua_State* L, int objindex, int enabled) { - const TValue* o = index2adr(L, objindex); + const TValue* o = index2addr(L, objindex); api_check(L, ttistable(o)); Table* t = hvalue(o); t->safeenv = bool(enabled); @@ -748,7 +748,7 @@ int lua_getmetatable(lua_State* L, int objindex) const TValue* obj; Table* mt = NULL; int res; - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); switch (ttype(obj)) { case LUA_TTABLE: @@ -775,7 +775,7 @@ int lua_getmetatable(lua_State* L, int objindex) void lua_getfenv(lua_State* L, int idx) { StkId o; - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); switch (ttype(o)) { @@ -801,7 +801,7 @@ void lua_settable(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_settable(L, t, L->top - 2, L->top - 1); L->top -= 2; /* pop index and value */ @@ -813,7 +813,7 @@ void lua_setfield(lua_State* L, int idx, const char* k) StkId t; TValue key; api_checknelems(L, 1); - t = index2adr(L, idx); + t = index2addr(L, idx); api_checkvalidindex(L, t); setsvalue(L, &key, luaS_new(L, k)); luaV_settable(L, t, &key, L->top - 1); @@ -825,7 +825,7 @@ void lua_rawset(lua_State* L, int idx) { StkId t; api_checknelems(L, 2); - t = index2adr(L, idx); + t = index2addr(L, idx); api_check(L, ttistable(t)); if (hvalue(t)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -839,7 +839,7 @@ void lua_rawseti(lua_State* L, int idx, int n) { StkId o; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_check(L, ttistable(o)); if (hvalue(o)->readonly) luaG_runerror(L, "Attempt to modify a readonly table"); @@ -854,7 +854,7 @@ int lua_setmetatable(lua_State* L, int objindex) TValue* obj; Table* mt; api_checknelems(L, 1); - obj = index2adr(L, objindex); + obj = index2addr(L, objindex); api_checkvalidindex(L, obj); if (ttisnil(L->top - 1)) mt = NULL; @@ -896,7 +896,7 @@ int lua_setfenv(lua_State* L, int idx) StkId o; int res = 1; api_checknelems(L, 1); - o = index2adr(L, idx); + o = index2addr(L, idx); api_checkvalidindex(L, o); api_check(L, ttistable(L->top - 1)); switch (ttype(o)) @@ -987,7 +987,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) func = 0; else { - StkId o = index2adr(L, errfunc); + StkId o = index2addr(L, errfunc); api_checkvalidindex(L, o); func = savestack(L, o); } @@ -1150,7 +1150,7 @@ l_noret lua_error(lua_State* L) int lua_next(lua_State* L, int idx) { luaC_checkthreadsleep(L); - StkId t = index2adr(L, idx); + StkId t = index2addr(L, idx); api_check(L, ttistable(t)); int more = luaH_next(L, hvalue(t), L->top - 1); if (more) @@ -1187,7 +1187,7 @@ void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz, tag); + Udata* u = luaU_newudata(L, sz, tag); setuvalue(L, L->top, u); api_incr_top(L); return u->data; @@ -1197,7 +1197,7 @@ void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)) { luaC_checkGC(L); luaC_checkthreadsleep(L); - Udata* u = luaS_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); + Udata* u = luaU_newudata(L, sz + sizeof(dtor), UTAG_IDTOR); memcpy(&u->data + sz, &dtor, sizeof(dtor)); setuvalue(L, L->top, u); api_incr_top(L); @@ -1232,7 +1232,7 @@ const char* lua_getupvalue(lua_State* L, int funcindex, int n) { luaC_checkthreadsleep(L); TValue* val; - const char* name = aux_upvalue(index2adr(L, funcindex), n, &val); + const char* name = aux_upvalue(index2addr(L, funcindex), n, &val); if (name) { setobj2s(L, L->top, val); @@ -1246,7 +1246,7 @@ const char* lua_setupvalue(lua_State* L, int funcindex, int n) const char* name; TValue* val; StkId fi; - fi = index2adr(L, funcindex); + fi = index2addr(L, funcindex); api_checknelems(L, 1); name = aux_upvalue(fi, n, &val); if (name) @@ -1270,7 +1270,7 @@ int lua_ref(lua_State* L, int idx) api_check(L, idx != LUA_REGISTRYINDEX); /* idx is a stack index for value */ int ref = LUA_REFNIL; global_State* g = L->global; - StkId p = index2adr(L, idx); + StkId p = index2addr(L, idx); if (!ttisnil(p)) { Table* reg = hvalue(registry(L)); diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index d77f84ef9..9fe1885fb 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -370,6 +370,69 @@ void lua_breakpoint(lua_State* L, int funcindex, int line, int enabled) luaG_breakpoint(L, clvalue(func)->l.p, line, bool(enabled)); } +static int getmaxline(Proto* p) +{ + int result = -1; + + for (int i = 0; i < p->sizecode; ++i) + { + int line = luaG_getline(p, i); + result = result < line ? line : result; + } + + for (int i = 0; i < p->sizep; ++i) + { + int psize = getmaxline(p->p[i]); + result = result < psize ? psize : result; + } + + return result; +} + +static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* context, lua_Coverage callback) +{ + memset(buffer, -1, size * sizeof(int)); + + for (int i = 0; i < p->sizecode; ++i) + { + Instruction insn = p->code[i]; + if (LUAU_INSN_OP(insn) != LOP_COVERAGE) + continue; + + int line = luaG_getline(p, i); + int hits = LUAU_INSN_E(insn); + + LUAU_ASSERT(size_t(line) < size); + buffer[line] = buffer[line] < hits ? hits : buffer[line]; + } + + const char* debugname = p->debugname ? getstr(p->debugname) : NULL; + int linedefined = luaG_getline(p, 0); + + callback(context, debugname, linedefined, depth, buffer, size); + + for (int i = 0; i < p->sizep; ++i) + getcoverage(p->p[i], depth + 1, buffer, size, context, callback); +} + +void lua_getcoverage(lua_State* L, int funcindex, void* context, lua_Coverage callback) +{ + const TValue* func = luaA_toobject(L, funcindex); + api_check(L, ttisfunction(func) && !clvalue(func)->isC); + + Proto* p = clvalue(func)->l.p; + + size_t size = getmaxline(p) + 1; + if (size == 0) + return; + + int* buffer = luaM_newarray(L, size, int, 0); + + getcoverage(p, 0, buffer, size, context, callback); + + luaM_freearray(L, buffer, size, int, 0); +} + static size_t append(char* buf, size_t bufsize, size_t offset, const char* data) { size_t size = strlen(data); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index ab416041e..7393fc74f 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -8,11 +8,10 @@ #include "lfunc.h" #include "lstring.h" #include "ldo.h" +#include "ludata.h" #include -LUAU_FASTFLAGVARIABLE(LuauSeparateAtomic, false) - LUAU_FASTFLAG(LuauArrayBoundary) #define GC_SWEEPMAX 40 @@ -59,10 +58,6 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSpropagate: case GCSpropagateagain: g->gcstats.currcycle.marktime += seconds; - - // atomic step had to be performed during the switch and it's tracked separately - if (!FFlag::LuauSeparateAtomic && g->gcstate == GCSsweepstring) - g->gcstats.currcycle.marktime -= g->gcstats.currcycle.atomictime; break; case GCSatomic: g->gcstats.currcycle.atomictime += seconds; @@ -488,7 +483,7 @@ static void freeobj(lua_State* L, GCObject* o) luaS_free(L, gco2ts(o)); break; case LUA_TUSERDATA: - luaS_freeudata(L, gco2u(o)); + luaU_freeudata(L, gco2u(o)); break; default: LUAU_ASSERT(0); @@ -632,17 +627,9 @@ static size_t remarkupvals(global_State* g) static size_t atomic(lua_State* L) { global_State* g = L->global; - size_t work = 0; - - if (FFlag::LuauSeparateAtomic) - { - LUAU_ASSERT(g->gcstate == GCSatomic); - } - else - { - g->gcstate = GCSatomic; - } + LUAU_ASSERT(g->gcstate == GCSatomic); + size_t work = 0; /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ @@ -666,11 +653,6 @@ static size_t atomic(lua_State* L) g->sweepgc = &g->rootgc; g->gcstate = GCSsweepstring; - if (!FFlag::LuauSeparateAtomic) - { - GC_INTERRUPT(GCSatomic); - } - return work; } @@ -716,22 +698,7 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - if (FFlag::LuauSeparateAtomic) - { - g->gcstate = GCSatomic; - } - else - { - double starttimestamp = lua_clock(); - - g->gcstats.currcycle.atomicstarttimestamp = starttimestamp; - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; - - atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); - - g->gcstats.currcycle.atomictime += lua_clock() - starttimestamp; - } + g->gcstate = GCSatomic; } break; } @@ -853,7 +820,7 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - ptrdiff_t lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ + int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -908,7 +875,7 @@ void luaC_fullgc(lua_State* L) if (g->gcstate == GCSpause) startGcCycleStats(g); - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; @@ -1049,7 +1016,7 @@ int64_t luaC_allocationrate(lua_State* L) global_State* g = L->global; const double durationthreshold = 1e-3; // avoid measuring intervals smaller than 1ms - if (g->gcstate <= (FFlag::LuauSeparateAtomic ? GCSatomic : GCSpropagateagain)) + if (g->gcstate <= GCSatomic) { double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index a79e7b953..f6f7a878f 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -7,6 +7,7 @@ #include "ltable.h" #include "lfunc.h" #include "lstring.h" +#include "ludata.h" #include #include diff --git a/VM/src/lobject.h b/VM/src/lobject.h index ba040af6c..fd0a15b75 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -78,15 +78,7 @@ typedef struct lua_TValue #define thvalue(o) check_exp(ttisthread(o), &(o)->value.gc->th) #define upvalue(o) check_exp(ttisupval(o), &(o)->value.gc->uv) -// beware bit magic: a value is false if it's nil or boolean false -// baseline implementation: (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) -// we'd like a branchless version of this which helps with performance, and a very fast version -// so our strategy is to always read the boolean value (not using bvalue(o) because that asserts when type isn't boolean) -// we then combine it with type to produce 0/1 as follows: -// - when type is nil (0), & makes the result 0 -// - when type is boolean (1), we effectively only look at the bottom bit, so result is 0 iff boolean value is 0 -// - when type is different, it must have some of the top bits set - we keep all top bits of boolean value so the result is non-0 -#define l_isfalse(o) (!(((o)->value.b | ~1) & ttype(o))) +#define l_isfalse(o) (ttisnil(o) || (ttisboolean(o) && bvalue(o) == 0)) /* ** for internal debug only diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 18ee1cda5..a9e90d17a 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -206,32 +206,3 @@ void luaS_free(lua_State* L, TString* ts) L->global->strt.nuse--; luaM_free(L, ts, sizestring(ts->len), ts->memcat); } - -Udata* luaS_newudata(lua_State* L, size_t s, int tag) -{ - if (s > INT_MAX - sizeof(Udata)) - luaM_toobig(L); - Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); - luaC_link(L, u, LUA_TUSERDATA); - u->len = int(s); - u->metatable = NULL; - LUAU_ASSERT(tag >= 0 && tag <= 255); - u->tag = uint8_t(tag); - return u; -} - -void luaS_freeudata(lua_State* L, Udata* u) -{ - LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); - - void (*dtor)(void*) = nullptr; - if (u->tag == UTAG_IDTOR) - memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); - else if (u->tag) - dtor = L->global->udatagc[u->tag]; - - if (dtor) - dtor(u->data); - - luaM_free(L, u, sizeudata(u->len), u->memcat); -} diff --git a/VM/src/lstring.h b/VM/src/lstring.h index 612da28d5..3fd0bd39b 100644 --- a/VM/src/lstring.h +++ b/VM/src/lstring.h @@ -8,11 +8,7 @@ /* string size limit */ #define MAXSSIZE (1 << 30) -/* special tag value is used for user data with inline dtors */ -#define UTAG_IDTOR LUA_UTAG_LIMIT - #define sizestring(len) (offsetof(TString, data) + len + 1) -#define sizeudata(len) (offsetof(Udata, data) + len) #define luaS_new(L, s) (luaS_newlstr(L, s, strlen(s))) #define luaS_newliteral(L, s) (luaS_newlstr(L, "" s, (sizeof(s) / sizeof(char)) - 1)) @@ -26,8 +22,5 @@ LUAI_FUNC void luaS_resize(lua_State* L, int newsize); LUAI_FUNC TString* luaS_newlstr(lua_State* L, const char* str, size_t l); LUAI_FUNC void luaS_free(lua_State* L, TString* ts); -LUAI_FUNC Udata* luaS_newudata(lua_State* L, size_t s, int tag); -LUAI_FUNC void luaS_freeudata(lua_State* L, Udata* u); - LUAI_FUNC TString* luaS_bufstart(lua_State* L, size_t size); LUAI_FUNC TString* luaS_buffinish(lua_State* L, TString* ts); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp new file mode 100644 index 000000000..d180c388e --- /dev/null +++ b/VM/src/ludata.cpp @@ -0,0 +1,37 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "ludata.h" + +#include "lgc.h" +#include "lmem.h" + +#include + +Udata* luaU_newudata(lua_State* L, size_t s, int tag) +{ + if (s > INT_MAX - sizeof(Udata)) + luaM_toobig(L); + Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); + luaC_link(L, u, LUA_TUSERDATA); + u->len = int(s); + u->metatable = NULL; + LUAU_ASSERT(tag >= 0 && tag <= 255); + u->tag = uint8_t(tag); + return u; +} + +void luaU_freeudata(lua_State* L, Udata* u) +{ + LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); + + void (*dtor)(void*) = nullptr; + if (u->tag == UTAG_IDTOR) + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); + else if (u->tag) + dtor = L->global->udatagc[u->tag]; + + if (dtor) + dtor(u->data); + + luaM_free(L, u, sizeudata(u->len), u->memcat); +} diff --git a/VM/src/ludata.h b/VM/src/ludata.h new file mode 100644 index 000000000..59cb85bd1 --- /dev/null +++ b/VM/src/ludata.h @@ -0,0 +1,13 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#pragma once + +#include "lobject.h" + +/* special tag value is used for user data with inline dtors */ +#define UTAG_IDTOR LUA_UTAG_LIMIT + +#define sizeudata(len) (offsetof(Udata, data) + len) + +LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); +LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index bf8d493eb..cebeeb584 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -63,7 +63,8 @@ #define VM_KV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->l.p->sizek)), &k[i]) #define VM_UV(i) (LUAU_ASSERT(unsigned(i) < unsigned(cl->nupvalues)), &cl->l.uprefs[i]) -#define VM_PATCH_C(pc, slot) ((uint8_t*)(pc))[3] = uint8_t(slot) +#define VM_PATCH_C(pc, slot) *const_cast(pc) = ((uint8_t(slot) << 24) | (0x00ffffffu & *(pc))) +#define VM_PATCH_E(pc, slot) *const_cast(pc) = ((uint32_t(slot) << 8) | (0x000000ffu & *(pc))) // NOTE: If debugging the Luau code, disable this macro to prevent timeouts from // occurring when tracing code in Visual Studio / XCode @@ -120,7 +121,7 @@ */ #if VM_USE_CGOTO #define VM_CASE(op) CASE_##op: -#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[*(uint8_t*)pc]) +#define VM_NEXT() goto*(SingleStep ? &&dispatch : kDispatchTable[LUAU_INSN_OP(*pc)]) #define VM_CONTINUE(op) goto* kDispatchTable[uint8_t(op)] #else #define VM_CASE(op) case op: @@ -325,7 +326,7 @@ static void luau_execute(lua_State* L) // ... and singlestep logic :) if (SingleStep) { - if (L->global->cb.debugstep && !luau_skipstep(*(uint8_t*)pc)) + if (L->global->cb.debugstep && !luau_skipstep(LUAU_INSN_OP(*pc))) { VM_PROTECT(luau_callhook(L, L->global->cb.debugstep, NULL)); @@ -335,13 +336,12 @@ static void luau_execute(lua_State* L) } #if VM_USE_CGOTO - VM_CONTINUE(*(uint8_t*)pc); + VM_CONTINUE(LUAU_INSN_OP(*pc)); #endif } #if !VM_USE_CGOTO - // Note: this assumes that LUAU_INSN_OP() decodes the first byte (aka least significant byte in the little endian encoding) - size_t dispatchOp = *(uint8_t*)pc; + size_t dispatchOp = LUAU_INSN_OP(*pc); dispatchContinue: switch (dispatchOp) @@ -2577,7 +2577,7 @@ static void luau_execute(lua_State* L) // update hits with saturated add and patch the instruction in place hits = (hits < (1 << 23) - 1) ? hits + 1 : hits; - ((uint32_t*)pc)[-1] = LOP_COVERAGE | (uint32_t(hits) << 8); + VM_PATCH_E(pc - 1, hits); VM_NEXT(); } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 740a4cfd2..5d802277a 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -53,7 +53,7 @@ const float* luaV_tovector(const TValue* obj) static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1, const TValue* p2) { ptrdiff_t result = savestack(L, res); - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers @@ -74,7 +74,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) { - // RBOLOX: using stack room beyond top is technically safe here, but for very complicated reasons: + // using stack room beyond top is technically safe here, but for very complicated reasons: // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers diff --git a/bench/tests/sunspider/3d-raytrace.lua b/bench/tests/sunspider/3d-raytrace.lua index 60e4f61e4..c8f6b5dcd 100644 --- a/bench/tests/sunspider/3d-raytrace.lua +++ b/bench/tests/sunspider/3d-raytrace.lua @@ -451,15 +451,16 @@ function raytraceScene() end function arrayToCanvasCommands(pixels) - local s = 'Test\nvar pixels = ['; + local s = {}; + table.insert(s, 'Test\nvar pixels = ['); for y = 0,size-1 do - s = s .. "["; + table.insert(s, "["); for x = 0,size-1 do - s = s .. "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"; + table.insert(s, "[" .. math.floor(pixels[y + 1][x + 1][1] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][2] * 255) .. "," .. math.floor(pixels[y + 1][x + 1][3] * 255) .. "],"); end - s = s .. "],"; + table.insert(s, "],"); end - s = s .. '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ + table.insert(s, '];\n var canvas = document.getElementById("renderCanvas").getContext("2d");\n\ \n\ \n\ var size = ' .. size .. ';\n\ @@ -479,9 +480,9 @@ for (var y = 0; y < size; y++) {\n\ canvas.setFillColor(l[0], l[1], l[2], 1);\n\ canvas.fillRect(x, y, 1, 1);\n\ }\n\ -}'; +}'); - return s; + return table.concat(s); end testOutput = arrayToCanvasCommands(raytraceScene()); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 3b74a99e4..62a9999b0 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -513,8 +513,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_the_end_of_a_comme TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check(R"( --[[ @1 )"); @@ -526,8 +524,6 @@ TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_co TEST_CASE_FIXTURE(ACFixture, "dont_offer_any_suggestions_from_within_a_broken_comment_at_the_very_end_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - check("--[[@1"); auto ac = autocomplete('1'); @@ -2625,4 +2621,55 @@ local a: A<(number, s@1> CHECK(ac.entryMap.count("string")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true); + + check(R"( +local function foo1() return 1 end +local function foo2() return "1" end + +local function bar0() return "got" .. a end +local function bar1(a: number) return "got " .. a end +local function bar2(a: number, b: string) return "got " .. a .. b end + +local t = {} +function t:bar1(a: number) return "got " .. a end + +local r1 = bar0(@1) +local r2 = bar1(@2) +local r3 = bar2(@3) +local r4 = t:bar1(@4) + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::None); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('2'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('3'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); + + ac = autocomplete('4'); + + REQUIRE(ac.entryMap.count("foo1")); + CHECK(ac.entryMap["foo1"].typeCorrect == TypeCorrectKind::CorrectFunctionResult); + REQUIRE(ac.entryMap.count("foo2")); + CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 6ba39adab..95811b3f4 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -10,9 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauPreloadClosures) -LUAU_FASTFLAG(LuauGenericSpecialGlobals) - using namespace Luau; static std::string compileFunction(const char* source, uint32_t id) @@ -74,20 +71,10 @@ TEST_CASE("BasicFunction") bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code); Luau::compileOrThrow(bcb, "local function foo(a, b) return b end"); - if (FFlag::LuauPreloadClosures) - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( + CHECK_EQ("\n" + bcb.dumpFunction(1), R"( DUPCLOSURE R0 K0 RETURN R0 0 )"); - } - else - { - CHECK_EQ("\n" + bcb.dumpFunction(1), R"( -NEWCLOSURE R0 P0 -RETURN R0 0 -)"); - } CHECK_EQ("\n" + bcb.dumpFunction(0), R"( RETURN R1 1 @@ -2859,47 +2846,35 @@ CAPTURE UPVAL U1 RETURN R0 1 )"); - if (FFlag::LuauPreloadClosures) - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( + // recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( DUPCLOSURE R0 K0 CAPTURE VAL R0 RETURN R0 0 )"); - // multi-level recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( + // multi-level recursive capture + CHECK_EQ("\n" + compileFunction("local function foo() return function() return foo() end end", 1), R"( DUPCLOSURE R0 K0 CAPTURE UPVAL U0 RETURN R0 1 )"); - // multi-level recursive capture where function isn't top-level - // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler - CHECK_EQ("\n" + compileFunction(R"( + // multi-level recursive capture where function isn't top-level + // note: this should probably be optimized to DUPCLOSURE but doing that requires a different upval tracking flow in the compiler + CHECK_EQ("\n" + compileFunction(R"( local function foo() local function bar() return function() return bar() end end end )", - 1), - R"( + 1), + R"( NEWCLOSURE R0 P0 CAPTURE UPVAL U0 RETURN R0 1 )"); - } - else - { - // recursive capture - CHECK_EQ("\n" + compileFunction("local function foo() return foo() end", 1), R"( -NEWCLOSURE R0 P0 -CAPTURE VAL R0 -RETURN R0 0 -)"); - } } TEST_CASE("OutOfLocals") @@ -3504,8 +3479,6 @@ local t = { TEST_CASE("ConstantClosure") { - ScopedFastFlag sff("LuauPreloadClosures", true); - // closures without upvalues are created when bytecode is loaded CHECK_EQ("\n" + compileFunction(R"( return function() end @@ -3570,8 +3543,6 @@ RETURN R0 1 TEST_CASE("SharedClosure") { - ScopedFastFlag sff1("LuauPreloadClosures", true); - // closures can be shared even if functions refer to upvalues, as long as upvalues are top-level CHECK_EQ("\n" + compileFunction(R"( local val = ... diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b2aad3163..b055a38e4 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -123,8 +123,8 @@ int lua_silence(lua_State* L) using StateRef = std::unique_ptr; -static StateRef runConformance( - const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, lua_State* initialLuaState = nullptr) +static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = nullptr, void (*yield)(lua_State* L) = nullptr, + lua_State* initialLuaState = nullptr, lua_CompileOptions* copts = nullptr) { std::string path = __FILE__; path.erase(path.find_last_of("\\/")); @@ -180,13 +180,8 @@ static StateRef runConformance( std::string chunkname = "=" + std::string(name); - lua_CompileOptions copts = {}; - copts.optimizationLevel = 1; // default - copts.debugLevel = 2; // for debugger tests - copts.vectorCtor = "vector"; // for vector tests - size_t bytecodeSize = 0; - char* bytecode = luau_compile(source.data(), source.size(), &copts, &bytecodeSize); + char* bytecode = luau_compile(source.data(), source.size(), copts, &bytecodeSize); int result = luau_load(L, chunkname.c_str(), bytecode, bytecodeSize, 0); free(bytecode); @@ -373,29 +368,37 @@ TEST_CASE("Vector") { ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("vector.lua", [](lua_State* L) { - lua_pushcfunction(L, lua_vector, "vector"); - lua_setglobal(L, "vector"); + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.vectorCtor = "vector"; + + runConformance( + "vector.lua", + [](lua_State* L) { + lua_pushcfunction(L, lua_vector, "vector"); + lua_setglobal(L, "vector"); #if LUA_VECTOR_SIZE == 4 - lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f, 0.0f); #else - lua_pushvector(L, 0.0f, 0.0f, 0.0f); + lua_pushvector(L, 0.0f, 0.0f, 0.0f); #endif - luaL_newmetatable(L, "vector"); + luaL_newmetatable(L, "vector"); - lua_pushstring(L, "__index"); - lua_pushcfunction(L, lua_vector_index, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__index"); + lua_pushcfunction(L, lua_vector_index, nullptr); + lua_settable(L, -3); - lua_pushstring(L, "__namecall"); - lua_pushcfunction(L, lua_vector_namecall, nullptr); - lua_settable(L, -3); + lua_pushstring(L, "__namecall"); + lua_pushcfunction(L, lua_vector_namecall, nullptr); + lua_settable(L, -3); - lua_setreadonly(L, -1, true); - lua_setmetatable(L, -2); - lua_pop(L, 1); - }); + lua_setreadonly(L, -1, true); + lua_setmetatable(L, -2); + lua_pop(L, 1); + }, + nullptr, nullptr, &copts); } static void populateRTTI(lua_State* L, Luau::TypeId type) @@ -499,6 +502,10 @@ TEST_CASE("Debugger") breakhits = 0; interruptedthread = nullptr; + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 2; + runConformance( "debugger.lua", [](lua_State* L) { @@ -614,7 +621,8 @@ TEST_CASE("Debugger") lua_resume(interruptedthread, nullptr, 0); interruptedthread = nullptr; } - }); + }, + nullptr, &copts); CHECK(breakhits == 10); // 2 hits per breakpoint } @@ -863,4 +871,46 @@ TEST_CASE("TagMethodError") }); } +TEST_CASE("Coverage") +{ + lua_CompileOptions copts = {}; + copts.optimizationLevel = 1; + copts.debugLevel = 1; + copts.coverageLevel = 2; + + runConformance( + "coverage.lua", + [](lua_State* L) { + lua_pushcfunction( + L, + [](lua_State* L) -> int { + luaL_argexpected(L, lua_isLfunction(L, 1), 1, "function"); + + lua_newtable(L); + lua_getcoverage(L, 1, L, [](void* context, const char* function, int linedefined, int depth, const int* hits, size_t size) { + lua_State* L = static_cast(context); + + lua_newtable(L); + + lua_pushstring(L, function); + lua_setfield(L, -2, "name"); + + for (size_t i = 0; i < size; ++i) + if (hits[i] != -1) + { + lua_pushinteger(L, hits[i]); + lua_rawseti(L, -2, int(i)); + } + + lua_rawseti(L, -2, lua_objlen(L, -2) + 1); + }); + + return 1; + }, + "getcoverage"); + lua_setglobal(L, "getcoverage"); + }, + nullptr, nullptr, &copts); +} + TEST_SUITE_END(); diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index 2800d2fe6..e3993cc53 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -278,7 +278,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") #if defined(_DEBUG) || defined(_NOOPT) int limit = 250; #else - int limit = 500; + int limit = 400; #endif ScopedFastInt luauTypeCloneRecursionLimit{"LuauTypeCloneRecursionLimit", limit}; @@ -287,7 +287,7 @@ TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") TypeId table = src.addType(TableTypeVar{}); TypeId nested = table; - for (unsigned i = 0; i < limit + 100; i++) + for (int i = 0; i < limit + 100; i++) { TableTypeVar* ttv = getMutable(nested); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 72d3a9a64..5abcb09ac 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2303,8 +2303,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_comments") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; @@ -2319,8 +2317,6 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") TEST_CASE_FIXTURE(Fixture, "capture_broken_comment") { - ScopedFastFlag sff{"LuauCaptureBrokenCommentSpans", true}; - ParseOptions options; options.captureComments = true; diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 091c2f012..71ff4e1b8 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -207,6 +207,36 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") CHECK_EQ("number", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + + CheckResult result = check(R"( + local a = 55 :: number? + local b = a :: number + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number?", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") +{ + ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; + ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + local a = 55 :: string + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Cannot cast 'number' into 'string' because the types are unrelated", toString(result.errors[0])); + CHECK_EQ("string", toString(requireType("a"))); +} + TEST_CASE_FIXTURE(Fixture, "type_annotations_inside_function_bodies") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1e2eae147..1d8135d4c 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -7,6 +7,8 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauFixTonumberReturnType) + using namespace Luau; TEST_SUITE_BEGIN("BuiltinTests"); @@ -814,6 +816,30 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); } +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') + )"); + + if (FFlag::LuauFixTonumberReturnType) + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); + } +} + +TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") +{ + CheckResult result = check(R"( + --!strict + local b: number = tonumber('asdf') or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index aba508918..b62044fa9 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,6 +9,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -644,4 +646,42 @@ f(1, 2, 3) CHECK_EQ(toString(*ty, opts), "(a: number, number, number) -> ()"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_types") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ( + toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") +{ + CheckResult result = check(R"( +type C = () -> () +type D = () -> () + +local c: C +local d: D = c + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauExtendedFunctionMismatchError) + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index fe8e7ff90..688680c10 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -279,7 +279,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1085,4 +1085,19 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 5f95efd52..1621ef32f 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -374,4 +374,54 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = { tag = 'cat', cafood = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Cat | Dog' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Cat' because the former is missing field 'catfood')", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + }; + + CheckResult result = check(R"( +type Good = { success: true, result: string } +type Bad = { success: false, error: string } +type Result = Good | Bad + +local a: Result = { success = false, result = 'something' } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type 'a' could not be converted into 'Bad | Good' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'Bad' because the former is missing field 'error')", + toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index cb72faaf4..3ea9b80c3 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -12,6 +12,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) + TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -275,7 +277,7 @@ TEST_CASE_FIXTURE(Fixture, "open_table_unification") TEST_CASE_FIXTURE(Fixture, "open_table_unification_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local a = {} @@ -346,7 +348,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_1") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -369,7 +371,7 @@ TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_2") TEST_CASE_FIXTURE(Fixture, "table_param_row_polymorphism_3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local T = {} @@ -476,7 +478,7 @@ TEST_CASE_FIXTURE(Fixture, "ok_to_add_property_to_free_table") TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_assignment") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -511,7 +513,7 @@ TEST_CASE_FIXTURE(Fixture, "okay_to_add_property_to_unsealed_tables_by_function_ TEST_CASE_FIXTURE(Fixture, "width_subtyping") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -771,7 +773,7 @@ TEST_CASE_FIXTURE(Fixture, "infer_indexer_for_left_unsealed_table_from_right_han TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local t: { a: string, [number]: string } = { a = "foo" } @@ -782,7 +784,7 @@ TEST_CASE_FIXTURE(Fixture, "sealed_table_value_can_infer_an_indexer") TEST_CASE_FIXTURE(Fixture, "array_factory_function") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( function empty() return {} end @@ -1465,7 +1467,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1550,7 +1552,7 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multipl TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_is_ok") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( local vec3 = {x = 1, y = 2, z = 3} @@ -1937,7 +1939,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in_strict") { - ScopedFastFlag sff{"LuauTableSubtypingVariance", true}; + ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( --!strict @@ -1952,7 +1954,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1971,7 +1973,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -1995,7 +1997,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag luauTableSubtypingVariance{"LuauTableSubtypingVariance", true}; // Only for new path + ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( @@ -2015,11 +2017,22 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' +caused by: + Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' +caused by: + Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); + } + else + { + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); + } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2027,7 +2040,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2048,7 +2061,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, {"LuauExtendedTypeMismatchError", true}, }; @@ -2076,7 +2089,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") ScopedFastFlag sffs[] { {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauTableSubtypingVariance", true}, + {"LuauTableSubtypingVariance2", true}, }; CheckResult result = check(R"( @@ -2092,4 +2105,18 @@ a.p = { x = 9 } LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") +{ + ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + + CheckResult result = check(R"( +local b +b = setmetatable({}, {__call = b}) +b() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index e3222a410..ad9ea8276 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,6 +16,7 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) using namespace Luau; @@ -2084,7 +2085,7 @@ TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") { CheckResult result = check(R"( function add(a: number, b: string) - return a + tonumber(b), a .. b + return a + (tonumber(b) :: number), a .. b end local n, s = add(2,"3") )"); @@ -2485,7 +2486,7 @@ TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") CheckResult result = check(R"( --!strict function f(U) - U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() end )"); @@ -3329,7 +3330,7 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable { CheckResult result = check(R"( local x - print((x == true and (x .. "y")) .. 1) + print((x == true and (x .. "y")) .. 1) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -4473,7 +4474,18 @@ f(function(a, b, c, ...) return a + b end) )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(number, number, a) -> number' could not be converted into '(number, number) -> number'", toString(result.errors[0])); + + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + } + else + { + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0])); + } // Infer from variadic packs into elements result = check(R"( @@ -4604,7 +4616,17 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i )"); LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + if (FFlag::LuauExtendedFunctionMismatchError) + { + CHECK_EQ( + "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); + } + else + { + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") @@ -4799,4 +4821,211 @@ local ModuleA = require(game.A) CHECK_EQ("*unknown*", toString(*oty)); } +/* + * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. + * + * We had an issue here where the scope for the `if` block here would + * have an elevated TypeLevel even though there is no function nesting going on. + * This would result in a free typevar for the type of _ that was much higher than + * it should be. This type would be erroneously quantified in the definition of `aaa`. + * This in turn caused an ice when evaluating `_()` in the while loop. + */ +TEST_CASE_FIXTURE(Fixture, "free_typevars_introduced_within_control_flow_constructs_do_not_get_an_elevated_TypeLevel") +{ + check(R"( + --!strict + if _ then + _[_], _ = nil + _() + end + + local aaa = function():typeof(_) return 1 end + + if aaa then + while _() do + end + end + )"); + + // No ice()? No problem. +} + +/* + * This is a bit elaborate. Bear with me. + * + * The type of _ becomes free with the first statement. With the second, we unify it with a function. + * + * At this point, it is important that the newly created fresh types of this new function type are promoted + * to the same level as the original free type. If we do not, they are incorrectly ascribed the level of the + * containing function. + * + * If this is allowed to happen, the final lambda erroneously quantifies the type of _ to something ridiculous + * just before we typecheck the invocation to _. + */ +TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") +{ + check(R"( + l0, _ = nil + + local function p() + _() + end + + a = _( + function():(typeof(p),typeof(_)) + end + )[nil] + )"); +} + +/* + * We had an issue where part of the type of pairs() was an unsealed table. + * This test depends on FFlagDebugLuauFreezeArena to trigger it. + */ +TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") +{ + check(R"( + function _(l0:{n0:any}) + _ = pairs + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") +{ + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + + Base64FileReader() + + function ReadMidiEvents(data) + + local reader = Base64FileReader(data) + + while reader:HasMore() do + (reader:Byte() % 128) + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, string) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' +caused by: + Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Function only returns 1 value. 2 are required here)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, number) -> number + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' +caused by: + Return type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") +{ + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +type A = (number, number) -> (number, string) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ( + toString(result.errors[0]), R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); +} + +TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +{ + ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; + + CheckResult result = check(R"( +local t = {} + +function t.x(value) + for k,v in pairs(t) do end +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 9f9a007f1..f55b46a40 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -214,4 +214,32 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } +TEST_CASE_FIXTURE(TryUnifyFixture, "undo_new_prop_on_unsealed_table") +{ + ScopedFastFlag flags[] = { + {"LuauTableSubtypingVariance2", true}, + }; + // I am not sure how to make this happen in Luau code. + + TypeId unsealedTable = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId sealedTable = arena.addType(TableTypeVar{ + {{"prop", Property{getSingletonTypes().numberType}}}, + std::nullopt, + TypeLevel{}, + TableState::Sealed + }); + + const TableTypeVar* ttv = get(unsealedTable); + REQUIRE(ttv); + + state.tryUnify(unsealedTable, sealedTable); + + // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. + CHECK(!ttv->props.empty()); + + state.log.rollback(); + + CHECK(ttv->props.empty()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 48496b895..b095a0db0 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -462,4 +462,20 @@ local a: XYZ = { w = 4 } CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X | Y | Z'; none of the union options are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") +{ + ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true}; + + CheckResult result = check(R"( +type X = { x: number } + +local a: X? = { w = 4 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'a' could not be converted into 'X?' +caused by: + None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); +} + TEST_SUITE_END(); diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.lua new file mode 100644 index 000000000..f899603f9 --- /dev/null +++ b/tests/conformance/coverage.lua @@ -0,0 +1,64 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing coverage") + +function foo() + local x = 1 + local y = 2 + assert(x + y) +end + +function bar() + local function one(x) + return x + end + + local two = function(x) + return x + end + + one(1) +end + +function validate(stats, hits, misses) + local checked = {} + + for _,l in ipairs(hits) do + if not (stats[l] and stats[l] > 0) then + return false, string.format("expected line %d to be hit", l) + end + checked[l] = true + end + + for _,l in ipairs(misses) do + if not (stats[l] and stats[l] == 0) then + return false, string.format("expected line %d to be missed", l) + end + checked[l] = true + end + + for k,v in pairs(stats) do + if type(k) == "number" and not checked[k] then + return false, string.format("expected line %d to be absent", k) + end + end + + return true +end + +foo() +c = getcoverage(foo) +assert(#c == 1) +assert(c[1].name == "foo") +assert(validate(c[1], {5, 6, 7}, {})) + +bar() +c = getcoverage(bar) +assert(#c == 3) +assert(c[1].name == "bar") +assert(validate(c[1], {11, 15, 19}, {})) +assert(c[2].name == "one") +assert(validate(c[2], {12}, {})) +assert(c[3].name == nil) +assert(validate(c[3], {}, {16})) + +return 'OK' From a9aa4faf24e6cea1ac0e33d0054a7328a35f9d4a Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 14:08:56 -0800 Subject: [PATCH 11/32] Sync to upstream/release/508 This version isn't for release because we've skipped some internal numbers due to year-end schedule changes, but it's better to merge separately. --- Analysis/include/Luau/LValue.h | 63 +++++++ Analysis/include/Luau/Predicate.h | 34 +--- Analysis/include/Luau/TypeInfer.h | 1 + Analysis/src/{Predicate.cpp => LValue.cpp} | 88 ++++++++- Analysis/src/TypeInfer.cpp | 83 ++++++++- Analysis/src/TypeVar.cpp | 11 +- Analysis/src/Unifier.cpp | 121 ++++--------- Ast/src/Parser.cpp | 4 - Sources.cmake | 5 +- VM/src/lbaselib.cpp | 8 +- tests/Autocomplete.test.cpp | 6 +- tests/LValue.test.cpp | 198 +++++++++++++++++++++ tests/Predicate.test.cpp | 117 ------------ tests/Symbol.test.cpp | 33 +++- tests/Transpiler.test.cpp | 2 - tests/TypeInfer.builtins.test.cpp | 12 +- tests/TypeInfer.classes.test.cpp | 2 - tests/TypeInfer.intersectionTypes.test.cpp | 4 - tests/TypeInfer.refinements.test.cpp | 37 +++- tests/TypeInfer.singletons.test.cpp | 1 - tests/TypeInfer.tables.test.cpp | 4 - tests/TypeInfer.test.cpp | 13 ++ tests/TypeInfer.unionTypes.test.cpp | 4 - tests/conformance/math.lua | 1 + 24 files changed, 569 insertions(+), 283 deletions(-) create mode 100644 Analysis/include/Luau/LValue.h rename Analysis/src/{Predicate.cpp => LValue.cpp} (50%) create mode 100644 tests/LValue.test.cpp delete mode 100644 tests/Predicate.test.cpp diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h new file mode 100644 index 000000000..8fd96f05a --- /dev/null +++ b/Analysis/include/Luau/LValue.h @@ -0,0 +1,63 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Variant.h" +#include "Luau/Symbol.h" + +#include // TODO: Kill with LuauLValueAsKey. +#include +#include + +namespace Luau +{ + +struct TypeVar; +using TypeId = const TypeVar*; + +struct Field; +using LValue = Variant; + +struct Field +{ + std::shared_ptr parent; + std::string key; + + bool operator==(const Field& rhs) const; + bool operator!=(const Field& rhs) const; +}; + +struct LValueHasher +{ + size_t operator()(const LValue& lvalue) const; +}; + +const LValue* baseof(const LValue& lvalue); + +std::optional tryGetLValue(const class AstExpr& expr); + +// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. +std::pair> getFullName(const LValue& lvalue); + +// Kill with LuauLValueAsKey. +std::string toString(const LValue& lvalue); + +template +const T* get(const LValue& lvalue) +{ + return get_if(&lvalue); +} + +using NEW_RefinementMap = std::unordered_map; +using DEPRECATED_RefinementMap = std::map; + +// Transient. Kill with LuauLValueAsKey. +struct RefinementMap +{ + NEW_RefinementMap NEW_refinements; + DEPRECATED_RefinementMap DEPRECATED_refinements; +}; + +void merge(RefinementMap& l, const RefinementMap& r, std::function f); +void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); + +} // namespace Luau diff --git a/Analysis/include/Luau/Predicate.h b/Analysis/include/Luau/Predicate.h index a5e8b6ae1..df93b4f49 100644 --- a/Analysis/include/Luau/Predicate.h +++ b/Analysis/include/Luau/Predicate.h @@ -1,12 +1,10 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once -#include "Luau/Variant.h" #include "Luau/Location.h" -#include "Luau/Symbol.h" +#include "Luau/LValue.h" +#include "Luau/Variant.h" -#include -#include #include namespace Luau @@ -15,34 +13,6 @@ namespace Luau struct TypeVar; using TypeId = const TypeVar*; -struct Field; -using LValue = Variant; - -struct Field -{ - std::shared_ptr parent; // TODO: Eventually use unique_ptr to enforce non-copyable trait. - std::string key; -}; - -std::optional tryGetLValue(const class AstExpr& expr); - -// Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. -std::pair> getFullName(const LValue& lvalue); - -std::string toString(const LValue& lvalue); - -template -const T* get(const LValue& lvalue) -{ - return get_if(&lvalue); -} - -// Key is a stringified encoding of an LValue. -using RefinementMap = std::map; - -void merge(RefinementMap& l, const RefinementMap& r, std::function f); -void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); - struct TruthyPredicate; struct IsAPredicate; struct TypeGuardPredicate; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 451976e48..862f50d79 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -350,6 +350,7 @@ struct TypeChecker private: std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); + std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); diff --git a/Analysis/src/Predicate.cpp b/Analysis/src/LValue.cpp similarity index 50% rename from Analysis/src/Predicate.cpp rename to Analysis/src/LValue.cpp index 7bd8001e3..da6804c6b 100644 --- a/Analysis/src/Predicate.cpp +++ b/Analysis/src/LValue.cpp @@ -1,11 +1,59 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Predicate.h" +#include "Luau/LValue.h" #include "Luau/Ast.h" +#include + +LUAU_FASTFLAG(LuauLValueAsKey) + namespace Luau { +bool Field::operator==(const Field& rhs) const +{ + LUAU_ASSERT(parent && rhs.parent); + return key == rhs.key && (parent == rhs.parent || *parent == *rhs.parent); +} + +bool Field::operator!=(const Field& rhs) const +{ + return !(*this == rhs); +} + +size_t LValueHasher::operator()(const LValue& lvalue) const +{ + // Most likely doesn't produce high quality hashes, but we're probably ok enough with it. + // When an evidence is shown that operator==(LValue) is used more often than it should, we can have a look at improving the hash quality. + size_t acc = 0; + size_t offset = 0; + + const LValue* current = &lvalue; + while (current) + { + if (auto field = get(*current)) + acc ^= (std::hash{}(field->key) << 1) >> ++offset; + else if (auto symbol = get(*current)) + acc ^= std::hash{}(*symbol) << 1; + else + LUAU_ASSERT(!"Hash not accumulated for this new LValue alternative."); + + current = baseof(*current); + } + + return acc; +} + +const LValue* baseof(const LValue& lvalue) +{ + if (auto field = get(lvalue)) + return field->parent.get(); + + auto symbol = get(lvalue); + LUAU_ASSERT(symbol); + return nullptr; // Base of root is null. +} + std::optional tryGetLValue(const AstExpr& node) { const AstExpr* expr = &node; @@ -38,15 +86,15 @@ std::pair> getFullName(const LValue& lvalue) while (auto field = get(*current)) { keys.push_back(field->key); - current = field->parent.get(); - if (!current) - LUAU_ASSERT(!"LValue root is a Field?"); + current = baseof(*current); } const Symbol* symbol = get(*current); + LUAU_ASSERT(symbol); return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } +// Kill with LuauLValueAsKey. std::string toString(const LValue& lvalue) { auto [symbol, keys] = getFullName(lvalue); @@ -56,7 +104,18 @@ std::string toString(const LValue& lvalue) return s; } -void merge(RefinementMap& l, const RefinementMap& r, std::function f) +static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::function f) +{ + for (const auto& [k, a] : r) + { + if (auto it = l.find(k); it != l.end()) + l[k] = f(it->second, a); + else + l[k] = a; + } +} + +static void merge(DEPRECATED_RefinementMap& l, const DEPRECATED_RefinementMap& r, std::function f) { auto itL = l.begin(); auto itR = r.begin(); @@ -69,21 +128,32 @@ void merge(RefinementMap& l, const RefinementMap& r, std::functionfirst > k) + else if (itL->first < k) + ++itL; + else { l[k] = a; ++itR; } - else - ++itL; } l.insert(itR, r.end()); } +void merge(RefinementMap& l, const RefinementMap& r, std::function f) +{ + if (FFlag::LuauLValueAsKey) + return merge(l.NEW_refinements, r.NEW_refinements, f); + else + return merge(l.DEPRECATED_refinements, r.DEPRECATED_refinements, f); +} + void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) { - refis[toString(lvalue)] = ty; + if (FFlag::LuauLValueAsKey) + refis.NEW_refinements[lvalue] = ty; + else + refis.DEPRECATED_refinements[toString(lvalue)] = ty; } } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index abbc2901b..e29b6ec65 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -36,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) +LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) @@ -1626,6 +1627,10 @@ std::optional TypeChecker::getIndexTypeFromType( { RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); + // Not needed when we normalize types. + if (FFlag::LuauLValueAsKey && get(follow(t))) + return t; + if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) goodOptions.push_back(*ty); else @@ -4967,13 +4972,83 @@ std::pair, std::vector> TypeChecker::createGener std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { - std::string path = toString(lvalue); + if (!FFlag::LuauLValueAsKey) + return DEPRECATED_resolveLValue(scope, lvalue); + + // We want to be walking the Scope parents. + // We'll also want to walk up the LValue path. As we do this, we need to save each LValue because we must walk back. + // For example: + // There exists an entry t.x. + // We are asked to look for t.x.y. + // We need to search in the provided Scope. Find t.x.y first. + // We fail to find t.x.y. Try t.x. We found it. Now we must return the type of the property y from the mapped-to type of t.x. + // If we completely fail to find the Symbol t but the Scope has that entry, then we should walk that all the way through and terminate. + const auto& [symbol, keys] = getFullName(lvalue); + + ScopePtr currentScope = scope; + while (currentScope) + { + std::optional found; + + std::vector childKeys; + const LValue* currentLValue = &lvalue; + while (currentLValue) + { + if (auto it = currentScope->refinements.NEW_refinements.find(*currentLValue); it != currentScope->refinements.NEW_refinements.end()) + { + found = it->second; + break; + } + + childKeys.push_back(*currentLValue); + currentLValue = baseof(*currentLValue); + } + + if (!found) + { + // Should not be using scope->lookup. This is already recursive. + if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) + found = it->second.typeId; + else + { + // Nothing exists in this Scope. Just skip and try the parent one. + currentScope = currentScope->parent; + continue; + } + } + + for (auto it = childKeys.rbegin(); it != childKeys.rend(); ++it) + { + const LValue& key = *it; + + // Symbol can happen. Skip. + if (get(key)) + continue; + else if (auto field = get(key)) + { + found = getIndexTypeFromType(scope, *found, field->key, Location(), false); + if (!found) + return std::nullopt; // Turns out this type doesn't have the property at all. We're done. + } + else + LUAU_ASSERT(!"New LValue alternative not handled here."); + } + + return found; + } + + // No entry for it at all. Can happen when LValue root is a global. + return std::nullopt; +} + +std::optional TypeChecker::DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue) +{ auto [symbol, keys] = getFullName(lvalue); ScopePtr currentScope = scope; while (currentScope) { - if (auto it = currentScope->refinements.find(path); it != currentScope->refinements.end()) + if (auto it = currentScope->refinements.DEPRECATED_refinements.find(toString(lvalue)); it != currentScope->refinements.DEPRECATED_refinements.end()) return it->second; // Should not be using scope->lookup. This is already recursive. @@ -5000,7 +5075,9 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV std::optional TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue) { - if (auto it = refis.find(toString(lvalue)); it != refis.end()) + if (auto it = refis.DEPRECATED_refinements.find(toString(lvalue)); it != refis.DEPRECATED_refinements.end()) + return it->second; + else if (auto it = refis.NEW_refinements.find(lvalue); it != refis.NEW_refinements.end()) return it->second; else return resolveLValue(scope, lvalue); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 571b13ca0..fb75aa02e 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -996,18 +996,19 @@ std::optional> magicFunctionFormat( std::vector expected = parseFormatString(typechecker, fmt->value.data, fmt->value.size); const auto& [params, tail] = flatten(paramPack); - const size_t dataOffset = 1; + size_t paramOffset = 1; + size_t dataOffset = expr.self ? 0 : 1; // unify the prefix one argument at a time - for (size_t i = 0; i < expected.size() && i + dataOffset < params.size(); ++i) + for (size_t i = 0; i < expected.size() && i + paramOffset < params.size(); ++i) { - Location location = expr.args.data[std::min(i, expr.args.size - 1)]->location; + Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(expected[i], params[i + dataOffset], location); + typechecker.unify(expected[i], params[i + paramOffset], location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error - const size_t actualParamSize = params.size() - dataOffset; + size_t actualParamSize = params.size() - paramOffset; if (expected.size() != actualParamSize && (!tail || expected.size() < actualParamSize)) typechecker.reportError(TypeError{expr.location, CountMismatch{expected.size(), actualParamSize}}); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index c5aab8562..43ea37e7b 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,9 +18,7 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedTypeMismatchError, false) LUAU_FASTFLAG(LuauSingletonTypes) -LUAU_FASTFLAGVARIABLE(LuauExtendedClassMismatchError, false) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) @@ -416,7 +414,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (!innerState.errors.empty()) { // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (FFlag::LuauExtendedTypeMismatchError && !firstFailedOption && !isNil(type)) + if (!firstFailedOption && !isNil(type)) firstFailedOption = {innerState.errors.front()}; failed = true; @@ -434,7 +432,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(*unificationTooComplex); else if (failed) { - if (FFlag::LuauExtendedTypeMismatchError && firstFailedOption) + if (firstFailedOption) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); @@ -536,49 +534,36 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) errors.push_back( TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); - else if (FFlag::LuauExtendedTypeMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } else if (const IntersectionTypeVar* uv = get(superTy)) { - if (FFlag::LuauExtendedTypeMismatchError) + std::optional unificationTooComplex; + std::optional firstFailedOption; + + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) { - std::optional unificationTooComplex; - std::optional firstFailedOption; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - { - if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; - } - - log.concat(std::move(innerState.log)); + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; } - if (unificationTooComplex) - errors.push_back(*unificationTooComplex); - else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); - } - else - { - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) - { - tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); - } + log.concat(std::move(innerState.log)); } + + if (unificationTooComplex) + errors.push_back(*unificationTooComplex); + else if (firstFailedOption) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else if (const IntersectionTypeVar* uv = get(subTy)) { @@ -626,10 +611,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(*unificationTooComplex); else if (!found) { - if (FFlag::LuauExtendedTypeMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } else if (get(superTy) && get(subTy)) @@ -1241,10 +1223,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, r->second.type); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1261,10 +1240,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, rt->indexer->indexResultType); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1302,10 +1278,7 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) Unifier innerState = makeChildUnifier(); innerState.tryUnify_(prop.type, lt->indexer->indexResultType); - if (FFlag::LuauExtendedTypeMismatchError) - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); - else - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, left, right); if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1723,18 +1696,11 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse innerState.tryUnify_(lhs->table, rhs->table); innerState.tryUnify_(lhs->metatable, rhs->metatable); - if (FFlag::LuauExtendedTypeMismatchError) - { - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); - else if (!innerState.errors.empty()) - errors.push_back( - TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); - } - else - { - checkChildUnifierTypeMismatch(innerState.errors, reversed ? other : metatable, reversed ? metatable : other); - } + if (auto e = hasUnificationTooComplex(innerState.errors)) + errors.push_back(*e); + else if (!innerState.errors.empty()) + errors.push_back( + TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); log.concat(std::move(innerState.log)); } @@ -1821,31 +1787,22 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) { ok = false; errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); - if (!FFlag::LuauExtendedClassMismatchError) - tryUnify_(prop.type, getSingletonTypes().errorRecoveryType()); } else { - if (FFlag::LuauExtendedClassMismatchError) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, classProp->type); + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(prop.type, classProp->type); - checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); + checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) - { - log.concat(std::move(innerState.log)); - } - else - { - ok = false; - innerState.log.rollback(); - } + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); } else { - tryUnify_(prop.type, classProp->type); + ok = false; + innerState.log.rollback(); } } } @@ -2185,8 +2142,6 @@ void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) { - LUAU_ASSERT(FFlag::LuauExtendedTypeMismatchError || FFlag::LuauExtendedClassMismatchError); - if (auto e = hasUnificationTooComplex(innerErrors)) errors.push_back(*e); else if (!innerErrors.empty()) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index dd24f27cb..72f616497 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauParseGenericFunctionTypeBegin, false) namespace Luau { @@ -1368,9 +1367,6 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) Lexeme parameterStart = lexer.current(); - if (!FFlag::LuauParseGenericFunctionTypeBegin) - begin = parameterStart; - expectAndConsume('(', "function parameters"); matchRecoveryStopOnToken[Lexeme::SkinnyArrow]++; diff --git a/Sources.cmake b/Sources.cmake index 14834b3a5..a7153eb37 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -45,6 +45,7 @@ target_sources(Luau.Analysis PRIVATE Analysis/include/Luau/IostreamHelpers.h Analysis/include/Luau/JsonEncoder.h Analysis/include/Luau/Linter.h + Analysis/include/Luau/LValue.h Analysis/include/Luau/Module.h Analysis/include/Luau/ModuleResolver.h Analysis/include/Luau/Predicate.h @@ -80,8 +81,8 @@ target_sources(Luau.Analysis PRIVATE Analysis/src/IostreamHelpers.cpp Analysis/src/JsonEncoder.cpp Analysis/src/Linter.cpp + Analysis/src/LValue.cpp Analysis/src/Module.cpp - Analysis/src/Predicate.cpp Analysis/src/Quantify.cpp Analysis/src/RequireTracer.cpp Analysis/src/Scope.cpp @@ -194,10 +195,10 @@ if(TARGET Luau.UnitTest) tests/Frontend.test.cpp tests/JsonEncoder.test.cpp tests/Linter.test.cpp + tests/LValue.test.cpp tests/Module.test.cpp tests/NonstrictMode.test.cpp tests/Parser.test.cpp - tests/Predicate.test.cpp tests/RequireTracer.test.cpp tests/StringUtils.test.cpp tests/Symbol.test.cpp diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 881c804db..988fd315e 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -36,12 +36,14 @@ static int luaB_tonumber(lua_State* L) int base = luaL_optinteger(L, 2, 10); if (base == 10) { /* standard conversion */ - luaL_checkany(L, 1); - if (lua_isnumber(L, 1)) + int isnum = 0; + double n = lua_tonumberx(L, 1, &isnum); + if (isnum) { - lua_pushnumber(L, lua_tonumber(L, 1)); + lua_pushnumber(L, n); return 1; } + luaL_checkany(L, 1); /* error if we don't have any argument */ } else { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 62a9999b0..8ca09c0e0 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1394,7 +1394,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_return_types") check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); @@ -1422,7 +1422,7 @@ return target(bar1, b@1 check(R"( local function target(a: number, b: string) return a + #b end local function bar1(a: number): (...number) return -a, a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); @@ -1918,7 +1918,7 @@ TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end -local function bar2(a: string) reutrn a .. 'x' end +local function bar2(a: string) return a .. 'x' end return target(b@1 )"); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp new file mode 100644 index 000000000..8a092779c --- /dev/null +++ b/tests/LValue.test.cpp @@ -0,0 +1,198 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/TypeInfer.h" + +#include "Fixture.h" +#include "ScopedFlags.h" + +#include "doctest.h" + +using namespace Luau; + +static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) +{ + Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { + // TODO: normalize here also. + std::unordered_set s; + + if (auto utv = get(follow(a))) + s.insert(begin(utv), end(utv)); + else + s.insert(a); + + if (auto utv = get(follow(b))) + s.insert(begin(utv), end(utv)); + else + s.insert(b); + + std::vector options(s.begin(), s.end()); + return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); + }); +} + +static LValue mkSymbol(const std::string& s) +{ + return Symbol{AstName{s.data()}}; +} + +TEST_SUITE_BEGIN("LValue"); + +TEST_CASE("Luau_merge_hashmap_order") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); +} + +TEST_CASE("Luau_merge_hashmap_order2") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + + RefinementMap m{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().numberType}, + }}; + + RefinementMap other{{ + {mkSymbol(b), getSingletonTypes().stringType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(3, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); +} + +TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + std::string a = "a"; + std::string b = "b"; + std::string c = "c"; + std::string d = "d"; + std::string e = "e"; + + RefinementMap m{{ + {mkSymbol(a), getSingletonTypes().stringType}, + {mkSymbol(b), getSingletonTypes().numberType}, + {mkSymbol(c), getSingletonTypes().booleanType}, + }}; + + RefinementMap other{{ + {mkSymbol(c), getSingletonTypes().stringType}, + {mkSymbol(d), getSingletonTypes().numberType}, + {mkSymbol(e), getSingletonTypes().booleanType}, + }}; + + TypeArena arena; + merge(arena, m, other); + + REQUIRE_EQ(5, m.NEW_refinements.size()); + REQUIRE(m.NEW_refinements.count(mkSymbol(a))); + REQUIRE(m.NEW_refinements.count(mkSymbol(b))); + REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + REQUIRE(m.NEW_refinements.count(mkSymbol(d))); + REQUIRE(m.NEW_refinements.count(mkSymbol(e))); + + CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); + CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(b)])); + CHECK_EQ("boolean | string", toString(m.NEW_refinements[mkSymbol(c)])); + CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(d)])); + CHECK_EQ("boolean", toString(m.NEW_refinements[mkSymbol(e)])); +} + +TEST_CASE("hashing_lvalue_global_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + LValue t_x1{Field{std::make_shared(Symbol{AstName{t1.data()}}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + LValue t_x2{Field{std::make_shared(Symbol{AstName{t2.data()}}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_EQ(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + NEW_RefinementMap m; + m[t_x1] = getSingletonTypes().stringType; + m[t_x2] = getSingletonTypes().numberType; + + CHECK_EQ(1, m.size()); +} + +TEST_CASE("hashing_lvalue_local_prop_access") +{ + std::string t1 = "t"; + std::string x1 = "x"; + + AstLocal localt1{AstName{t1.data()}, Location(), nullptr, 0, 0, nullptr}; + LValue t_x1{Field{std::make_shared(Symbol{&localt1}), x1}}; + + std::string t2 = "t"; + std::string x2 = "x"; + + AstLocal localt2{AstName{t2.data()}, Location(), &localt1, 0, 0, nullptr}; + LValue t_x2{Field{std::make_shared(Symbol{&localt2}), x2}}; + + CHECK_EQ(t_x1, t_x1); + CHECK_NE(t_x1, t_x2); + CHECK_EQ(t_x2, t_x2); + + CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x1)); + CHECK_NE(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); + CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); + + NEW_RefinementMap m; + m[t_x1] = getSingletonTypes().stringType; + m[t_x2] = getSingletonTypes().numberType; + + CHECK_EQ(2, m.size()); +} + +TEST_SUITE_END(); diff --git a/tests/Predicate.test.cpp b/tests/Predicate.test.cpp deleted file mode 100644 index 7081693e2..000000000 --- a/tests/Predicate.test.cpp +++ /dev/null @@ -1,117 +0,0 @@ -// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/TypeInfer.h" - -#include "Fixture.h" -#include "ScopedFlags.h" - -#include "doctest.h" - -using namespace Luau; - -static void merge(TypeArena& arena, RefinementMap& l, const RefinementMap& r) -{ - Luau::merge(l, r, [&arena](TypeId a, TypeId b) -> TypeId { - // TODO: normalize here also. - std::unordered_set s; - - if (auto utv = get(follow(a))) - s.insert(begin(utv), end(utv)); - else - s.insert(a); - - if (auto utv = get(follow(b))) - s.insert(begin(utv), end(utv)); - else - s.insert(b); - - std::vector options(s.begin(), s.end()); - return options.size() == 1 ? options[0] : arena.addType(UnionTypeVar{std::move(options)}); - }); -} - -TEST_SUITE_BEGIN("Predicate"); - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order") -{ - RefinementMap m{ - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "Luau_merge_hashmap_order2") -{ - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.stringType}, - {"c", typeChecker.numberType}, - }; - - RefinementMap other{ - {"b", typeChecker.stringType}, - {"c", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(3, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("string", toString(m["b"])); - CHECK_EQ("boolean | number", toString(m["c"])); -} - -TEST_CASE_FIXTURE(Fixture, "one_map_has_overlap_at_end_whereas_other_has_it_in_start") -{ - RefinementMap m{ - {"a", typeChecker.stringType}, - {"b", typeChecker.numberType}, - {"c", typeChecker.booleanType}, - }; - - RefinementMap other{ - {"c", typeChecker.stringType}, - {"d", typeChecker.numberType}, - {"e", typeChecker.booleanType}, - }; - - TypeArena arena; - merge(arena, m, other); - - REQUIRE_EQ(5, m.size()); - REQUIRE(m.count("a")); - REQUIRE(m.count("b")); - REQUIRE(m.count("c")); - REQUIRE(m.count("d")); - REQUIRE(m.count("e")); - - CHECK_EQ("string", toString(m["a"])); - CHECK_EQ("number", toString(m["b"])); - CHECK_EQ("boolean | string", toString(m["c"])); - CHECK_EQ("number", toString(m["d"])); - CHECK_EQ("boolean", toString(m["e"])); -} - -TEST_SUITE_END(); diff --git a/tests/Symbol.test.cpp b/tests/Symbol.test.cpp index 44fe3a3c7..e7d2973b8 100644 --- a/tests/Symbol.test.cpp +++ b/tests/Symbol.test.cpp @@ -10,7 +10,7 @@ using namespace Luau; TEST_SUITE_BEGIN("SymbolTests"); -TEST_CASE("hashing") +TEST_CASE("hashing_globals") { std::string s1 = "name"; std::string s2 = "name"; @@ -31,10 +31,37 @@ TEST_CASE("hashing") CHECK_EQ(std::hash()(two), std::hash()(two)); std::unordered_map theMap; - theMap[AstName{s1.data()}] = 5; - theMap[AstName{s2.data()}] = 1; + theMap[n1] = 5; + theMap[n2] = 1; REQUIRE_EQ(1, theMap.size()); } +TEST_CASE("hashing_locals") +{ + std::string s1 = "name"; + std::string s2 = "name"; + + // These two names point to distinct memory areas. + AstLocal one{AstName{s1.data()}, Location(), nullptr, 0, 0, nullptr}; + AstLocal two{AstName{s2.data()}, Location(), &one, 0, 0, nullptr}; + + Symbol n1{&one}; + Symbol n2{&two}; + + CHECK(n1 == n1); + CHECK(n1 != n2); + CHECK(n2 == n2); + + CHECK_EQ(std::hash()(&one), std::hash()(&one)); + CHECK_NE(std::hash()(&one), std::hash()(&two)); + CHECK_EQ(std::hash()(&two), std::hash()(&two)); + + std::unordered_map theMap; + theMap[n1] = 5; + theMap[n2] = 1; + + REQUIRE_EQ(2, theMap.size()); +} + TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 327fa0bbd..47c3883c1 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -555,8 +555,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_assign_multiple") TEST_CASE_FIXTURE(Fixture, "transpile_generic_function") { - ScopedFastFlag luauParseGenericFunctionTypeBegin("LuauParseGenericFunctionTypeBegin", true); - std::string code = R"( local function foo(a: T, ...: S...) return 1 end local f: (T, S...)->(number) = foo diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 1d8135d4c..506279b94 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -798,13 +798,14 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi { CheckResult result = check(R"( ("%s%d%s"):format(1, "hello", true) + string.format("%s%d%s", 1, "hello", true) )"); TypeId stringType = typeChecker.stringType; TypeId numberType = typeChecker.numberType; TypeId booleanType = typeChecker.booleanType; - LUAU_REQUIRE_ERROR_COUNT(3, result); + LUAU_REQUIRE_ERROR_COUNT(6, result); CHECK_EQ(Location(Position{1, 26}, Position{1, 27}), result.errors[0].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[0].data); @@ -814,6 +815,15 @@ TEST_CASE_FIXTURE(Fixture, "string_format_report_all_type_errors_at_correct_posi CHECK_EQ(Location(Position{1, 38}, Position{1, 42}), result.errors[2].location); CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[2].data); + + CHECK_EQ(Location(Position{2, 32}, Position{2, 33}), result.errors[3].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, numberType}), result.errors[3].data); + + CHECK_EQ(Location(Position{2, 35}, Position{2, 42}), result.errors[4].location); + CHECK_EQ(TypeErrorData(TypeMismatch{numberType, stringType}), result.errors[4].data); + + CHECK_EQ(Location(Position{2, 44}, Position{2, 48}), result.errors[5].location); + CHECK_EQ(TypeErrorData(TypeMismatch{stringType, booleanType}), result.errors[5].data); } TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 1ff23fe69..0283ae192 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -449,8 +449,6 @@ b.X = 2 -- real Vector2.X is also read-only TEST_CASE_FIXTURE(ClassFixture, "detailed_class_unification_error") { - ScopedFastFlag luauExtendedClassMismatchError{"LuauExtendedClassMismatchError", true}; - CheckResult result = check(R"( local function foo(v) return v.X :: number + string.len(v.Y) diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 893bc2b30..93c0baf6d 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -343,8 +343,6 @@ TEST_CASE_FIXTURE(Fixture, "table_intersection_setmetatable") TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_part") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -363,8 +361,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_intersection_all") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 688680c10..503b613f4 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -280,7 +280,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_non_binary_expressions_actually_resolve_const TEST_CASE_FIXTURE(Fixture, "assign_table_with_refined_property_with_a_similar_type_is_illegal") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local t: {x: number?} = {x = nil} @@ -1085,6 +1084,41 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") CHECK_EQ("any", toString(requireTypeAtPosition({6, 66}))); } +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + local foo: string? = "hi" + assert(foo) + local foo: number = 5 + print(foo:sub(1, 1)) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Type 'number' does not have key 'sub'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + type T = {x: string | number} + local t: T? = {x = "hi"} + if t then + if type(t.x) == "string" then + local foo = t.x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); +} + TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; @@ -1092,6 +1126,7 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip CheckResult result = check(R"( type T = { [string]: { prop: number }? } local t: T = {} + if t["hello"] then local foo = t["hello"].prop end diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 1621ef32f..68dc1b4fa 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -202,7 +202,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauExtendedTypeMismatchError", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 3ea9b80c3..80f40407e 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1955,7 +1955,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_should_cope_with_optional_properties_in TEST_CASE_FIXTURE(Fixture, "error_detailed_prop") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( type A = { x: number, y: number } @@ -1974,7 +1973,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_prop_nested") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( type AS = { x: number, y: number } @@ -1998,7 +1996,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); @@ -2062,7 +2059,6 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, - {"LuauExtendedTypeMismatchError", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index ad9ea8276..76324556a 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5013,6 +5013,19 @@ caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } +TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + local function f(thing: any | string) + local foo = thing.SomeRandomKey + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") { ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index b095a0db0..2357869e9 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -425,8 +425,6 @@ y = x TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } @@ -446,8 +444,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_union_all") { - ScopedFastFlag luauExtendedTypeMismatchError{"LuauExtendedTypeMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } type Y = { y: number } diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index d5bca44f0..bfea0e1f1 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -26,6 +26,7 @@ function f(...) end end +assert(pcall(tonumber) == false) assert(tonumber{} == nil) assert(tonumber'+0.01' == 1/100 and tonumber'+.01' == 0.01 and tonumber'.01' == 0.01 and tonumber'-1.' == -1 and From 44ccd8282244228a3a3108cb8e5237c45b9d92c2 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 14:10:07 -0800 Subject: [PATCH 12/32] Sync to upstream/release/509 --- .gitignore | 18 +- Analysis/include/Luau/Frontend.h | 10 +- Analysis/include/Luau/TxnLog.h | 280 ++- Analysis/include/Luau/TypeInfer.h | 40 +- Analysis/include/Luau/TypePack.h | 8 + Analysis/include/Luau/TypeVar.h | 1 + Analysis/include/Luau/TypedAllocator.h | 23 +- Analysis/include/Luau/Unifier.h | 48 +- Analysis/src/Autocomplete.cpp | 33 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 7 +- Analysis/src/Frontend.cpp | 28 +- Analysis/src/Module.cpp | 2 +- Analysis/src/Quantify.cpp | 14 +- Analysis/src/ToString.cpp | 48 +- Analysis/src/TxnLog.cpp | 295 ++- Analysis/src/TypeInfer.cpp | 807 +++++-- Analysis/src/TypePack.cpp | 51 +- Analysis/src/TypeVar.cpp | 19 +- Analysis/src/TypedAllocator.cpp | 1 + Analysis/src/Unifier.cpp | 2250 +++++++++++++------ Ast/include/Luau/Common.h | 8 +- CLI/Analyze.cpp | 2 +- CLI/Repl.cpp | 46 +- CLI/Web.cpp | 24 +- Compiler/include/Luau/Bytecode.h | 1 + Compiler/include/Luau/BytecodeBuilder.h | 2 + Compiler/src/BytecodeBuilder.cpp | 69 +- Compiler/src/Compiler.cpp | 7 +- Sources.cmake | 1 + VM/include/luaconf.h | 4 - VM/src/lapi.cpp | 34 +- VM/src/laux.cpp | 38 +- VM/src/lbitlib.cpp | 8 - VM/src/ldebug.cpp | 17 +- VM/src/ldo.cpp | 29 +- VM/src/lgc.cpp | 2 - VM/src/lgcdebug.cpp | 7 +- VM/src/lnumprint.cpp | 375 ++++ VM/src/lnumutils.h | 7 +- VM/src/lobject.h | 1 + VM/src/ltable.cpp | 6 +- VM/src/lvmload.cpp | 28 +- VM/src/lvmutils.cpp | 7 +- fuzz/number.cpp | 35 + tests/AstQuery.test.cpp | 2 - tests/Autocomplete.test.cpp | 43 +- tests/Conformance.test.cpp | 31 +- tests/Fixture.cpp | 5 - tests/Fixture.h | 9 - tests/Frontend.test.cpp | 7 +- tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.builtins.test.cpp | 37 +- tests/TypeInfer.generics.test.cpp | 2 - tests/TypeInfer.test.cpp | 2 - tests/TypeInfer.tryUnify.test.cpp | 54 +- tests/TypeInfer.typePacks.cpp | 14 +- tests/TypeInfer.unionTypes.test.cpp | 10 +- tests/TypeVar.test.cpp | 4 +- tests/conformance/debug.lua | 9 + tests/conformance/strconv.lua | 51 + tests/main.cpp | 23 +- tools/numprint.py | 82 + 62 files changed, 3827 insertions(+), 1301 deletions(-) create mode 100644 VM/src/lnumprint.cpp create mode 100644 fuzz/number.cpp create mode 100644 tests/conformance/strconv.lua create mode 100644 tools/numprint.py diff --git a/.gitignore b/.gitignore index fa11b45b5..5688dff52 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,10 @@ -^build/ -^coverage/ -^fuzz/luau.pb.* -^crash-* -^default.prof* -^fuzz-* -^luau$ -/.vs +/build/ +/build[.-]*/ +/coverage/ +/.vs/ +/.vscode/ +/fuzz/luau.pb.* +/crash-* +/default.prof* +/fuzz-* +/luau diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 07a0296a2..1f64db30c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -68,7 +68,7 @@ struct FrontendOptions // is complete. bool retainFullTypeGraphs = false; - // When true, we run typechecking twice, one in the regular mode, ond once in strict mode + // When true, we run typechecking twice, once in the regular mode, and once in strict mode // in order to get more precise type information (e.g. for autocomplete). bool typecheckTwice = false; }; @@ -109,18 +109,18 @@ struct Frontend Frontend(FileResolver* fileResolver, ConfigResolver* configResolver, const FrontendOptions& options = {}); - CheckResult check(const ModuleName& name); // new shininess - LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); + CheckResult check(const ModuleName& name, std::optional optionOverride = {}); // new shininess + LintResult lint(const ModuleName& name, std::optional enabledLintWarnings = {}); /** Lint some code that has no associated DataModel object * * Since this source fragment has no name, we cannot cache its AST. Instead, * we return it to the caller to use as they wish. */ - std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); + std::pair lintFragment(std::string_view source, std::optional enabledLintWarnings = {}); CheckResult check(const SourceModule& module); // OLD. TODO KILL - LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); + LintResult lint(const SourceModule& module, std::optional enabledLintWarnings = {}); bool isDirty(const ModuleName& name) const; void markDirty(const ModuleName& name, std::vector* markedDirty = nullptr); diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 29988a3b9..dc45bebf4 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -1,7 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include +#include + #include "Luau/TypeVar.h" +#include "Luau/TypePack.h" LUAU_FASTFLAG(LuauShareTxnSeen); @@ -9,27 +13,28 @@ namespace Luau { // Log of where what TypeIds we are rebinding and what they used to be -struct TxnLog +// Remove with LuauUseCommitTxnLog +struct DEPRECATED_TxnLog { - TxnLog() + DEPRECATED_TxnLog() : originalSeenSize(0) , ownedSeen() , sharedSeen(&ownedSeen) { } - explicit TxnLog(std::vector>* sharedSeen) + explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) : originalSeenSize(sharedSeen->size()) , ownedSeen() , sharedSeen(sharedSeen) { } - TxnLog(const TxnLog&) = delete; - TxnLog& operator=(const TxnLog&) = delete; + DEPRECATED_TxnLog(const DEPRECATED_TxnLog&) = delete; + DEPRECATED_TxnLog& operator=(const DEPRECATED_TxnLog&) = delete; - TxnLog(TxnLog&&) = default; - TxnLog& operator=(TxnLog&&) = default; + DEPRECATED_TxnLog(DEPRECATED_TxnLog&&) = default; + DEPRECATED_TxnLog& operator=(DEPRECATED_TxnLog&&) = default; void operator()(TypeId a); void operator()(TypePackId a); @@ -37,7 +42,7 @@ struct TxnLog void rollback(); - void concat(TxnLog rhs); + void concat(DEPRECATED_TxnLog rhs); bool haveSeen(TypeId lhs, TypeId rhs); void pushSeen(TypeId lhs, TypeId rhs); @@ -54,4 +59,263 @@ struct TxnLog std::vector>* sharedSeen; // shared with all the descendent logs }; +// Pending state for a TypeVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingType +{ + // The pending TypeVar state. + TypeVar pending; + + explicit PendingType(TypeVar state) + : pending(std::move(state)) + { + } +}; + +// Pending state for a TypePackVar. Generated by a TxnLog and committed via +// TxnLog::commit. +struct PendingTypePack +{ + // The pending TypePackVar state. + TypePackVar pending; + + explicit PendingTypePack(TypePackVar state) + : pending(std::move(state)) + { + } +}; + +template +T* getMutable(PendingType* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +template +T* getMutable(PendingTypePack* pending) +{ + // We use getMutable here because this state is intended to be mutated freely. + return getMutable(&pending->pending); +} + +// Log of what TypeIds we are rebinding, to be committed later. +struct TxnLog +{ + TxnLog() + : ownedSeen() + , sharedSeen(&ownedSeen) + { + } + + explicit TxnLog(TxnLog* parent) + : parent(parent) + { + if (parent) + { + sharedSeen = parent->sharedSeen; + } + else + { + sharedSeen = &ownedSeen; + } + } + + explicit TxnLog(std::vector>* sharedSeen) + : sharedSeen(sharedSeen) + { + } + + TxnLog(TxnLog* parent, std::vector>* sharedSeen) + : parent(parent) + , sharedSeen(sharedSeen) + { + } + + TxnLog(const TxnLog&) = delete; + TxnLog& operator=(const TxnLog&) = delete; + + TxnLog(TxnLog&&) = default; + TxnLog& operator=(TxnLog&&) = default; + + // Gets an empty TxnLog pointer. This is useful for constructs that + // take a TxnLog, like TypePackIterator - use the empty log if you + // don't have a TxnLog to give it. + static const TxnLog* empty(); + + // Joins another TxnLog onto this one. You should use std::move to avoid + // copying the rhs TxnLog. + // + // If both logs talk about the same type, pack, or table, the rhs takes + // priority. + void concat(TxnLog rhs); + + // Commits the TxnLog, rebinding all type pointers to their pending states. + // Clears the TxnLog afterwards. + void commit(); + + // Clears the TxnLog without committing any pending changes. + void clear(); + + // Computes an inverse of this TxnLog at the current time. + // This method should be called before commit is called in order to give an + // accurate result. Committing the inverse of a TxnLog will undo the changes + // made by commit, assuming the inverse log is accurate. + TxnLog inverse(); + + bool haveSeen(TypeId lhs, TypeId rhs) const; + void pushSeen(TypeId lhs, TypeId rhs); + void popSeen(TypeId lhs, TypeId rhs); + + // Queues a type for modification. The original type will not change until commit + // is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* queue(TypeId ty); + + // Queues a type pack for modification. The original type pack will not change + // until commit is called. Use pending to get the pending state. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* queue(TypePackId tp); + + // Returns the pending state of a type, or nullptr if there isn't any. It is important + // to note that this pending state is not transitive: the pending state may reference + // non-pending types freely, so you may need to call pending multiple times to view the + // entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* pending(TypeId ty) const; + + // Returns the pending state of a type pack, or nullptr if there isn't any. It is + // important to note that this pending state is not transitive: the pending state may + // reference non-pending types freely, so you may need to call pending multiple times + // to view the entire pending state of a type graph. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* pending(TypePackId tp) const; + + // Queues a replacement of a type with another type. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* replace(TypeId ty, TypeVar replacement); + + // Queues a replacement of a type pack with another type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* replace(TypePackId tp, TypePackVar replacement); + + // Queues a replacement of a table type with another table type that is bound + // to a specific value. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* bindTable(TypeId ty, std::optional newBoundTo); + + // Queues a replacement of a type with a level with a duplicate of that type + // with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeLevel(TypeId ty, TypeLevel newLevel); + + // Queues a replacement of a type pack with a level with a duplicate of that + // type pack with a new type level. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingTypePack* changeLevel(TypePackId tp, TypeLevel newLevel); + + // Queues a replacement of a table type with another table type with a new + // indexer. + // + // The pointer returned lives until `commit` or `clear` is called. + PendingType* changeIndexer(TypeId ty, std::optional indexer); + + // Returns the type level of the pending state of the type, or the level of that + // type, if no pending state exists. If the type doesn't have a notion of a level, + // returns nullopt. If the pending state doesn't have a notion of a level, but the + // original state does, returns nullopt. + std::optional getLevel(TypeId ty) const; + + // Follows a type, accounting for pending type states. The returned type may have + // pending state; you should use `pending` or `get` to find out. + TypeId follow(TypeId ty); + + // Follows a type pack, accounting for pending type states. The returned type pack + // may have pending state; you should use `pending` or `get` to find out. + TypePackId follow(TypePackId tp) const; + + // Replaces a given type's state with a new variant. Returns the new pending state + // of that type. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingType* replace(TypeId ty, T replacement) + { + return replace(ty, TypeVar(replacement)); + } + + // Replaces a given type pack's state with a new variant. Returns the new + // pending state of that type pack. + // + // The pointer returned lives until `commit` or `clear` is called. + template + PendingTypePack* replace(TypePackId tp, T replacement) + { + return replace(tp, TypePackVar(replacement)); + } + + // Returns T if a given type or type pack is this variant, respecting the + // log's pending state. + // + // Do not retain this pointer; it has the potential to be invalidated when + // commit or clear is called. + template + T* getMutable(TID ty) const + { + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::getMutable(pendingTy); + + return Luau::getMutable(ty); + } + + // Returns whether a given type or type pack is a given state, respecting the + // log's pending state. + // + // This method will not assert if called on a BoundTypeVar or BoundTypePack. + template + bool is(TID ty) const + { + // We do not use getMutable here because this method can be called on + // BoundTypeVars, which triggers an assertion. + auto* pendingTy = pending(ty); + if (pendingTy) + return Luau::get_if(&pendingTy->pending.ty) != nullptr; + + return Luau::get_if(&ty->ty) != nullptr; + } + +private: + // unique_ptr is used to give us stable pointers across insertions into the + // map. Otherwise, it would be really easy to accidentally invalidate the + // pointers returned from queue/pending. + // + // We can't use a DenseHashMap here because we need a non-const iterator + // over the map when we concatenate. + std::unordered_map> typeVarChanges; + std::unordered_map> typePackChanges; + + TxnLog* parent = nullptr; + + // Owned version of sharedSeen. This should not be accessed directly in + // TxnLogs; use sharedSeen instead. This field exists because in the tree + // of TxnLogs, the root must own its seen set. In all descendant TxnLogs, + // this is an empty vector. + std::vector> ownedSeen; + +public: + // Used to avoid infinite recursion when types are cyclic. + // Shared with all the descendent TxnLogs. + std::vector>* sharedSeen; +}; + } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 862f50d79..312283b05 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -198,32 +198,32 @@ struct TypeChecker */ TypeId anyIfNonstrict(TypeId ty) const; - /** Attempt to unify the types left and right. Treat any failures as type errors - * in the final typecheck report. + /** Attempt to unify the types. + * Treat any failures as type errors in the final typecheck report. */ - bool unify(TypeId left, TypeId right, const Location& location); - bool unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); + bool unify(TypeId subTy, TypeId superTy, const Location& location); + bool unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); - /** Attempt to unify the types left and right. - * If this fails, and the right type can be instantiated, do so and try unification again. + /** Attempt to unify the types. + * If this fails, and the subTy type can be instantiated, do so and try unification again. */ - bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location); - void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state); + bool unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location); + void unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state); - /** Attempt to unify left with right. + /** Attempt to unify. * If there are errors, undo everything and return the errors. * If there are no errors, commit and return an empty error vector. */ - ErrorVec tryUnify(TypeId left, TypeId right, const Location& location); - ErrorVec tryUnify(TypePackId left, TypePackId right, const Location& location); + template + ErrorVec tryUnify_(Id subTy, Id superTy, const Location& location); + ErrorVec tryUnify(TypeId subTy, TypeId superTy, const Location& location); + ErrorVec tryUnify(TypePackId subTy, TypePackId superTy, const Location& location); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy, const Location& location); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, const Location& location); - - // Variant that takes a preexisting 'seen' set. We need this in certain cases to avoid infinitely recursing - // into cyclic types. - ErrorVec canUnify(const std::vector>& seen, TypeId left, TypeId right, const Location& location); + template + ErrorVec canUnify_(Id subTy, Id superTy, const Location& location); + ErrorVec canUnify(TypeId subTy, TypeId superTy, const Location& location); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, const Location& location); std::optional findMetatableEntry(TypeId type, std::string entry, const Location& location); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location); @@ -237,12 +237,6 @@ struct TypeChecker std::optional tryStripUnionFromNil(TypeId ty); TypeId stripFromNilAndReport(TypeId ty, const Location& location); - template - ErrorVec tryUnify_(Id left, Id right, const Location& location); - - template - ErrorVec canUnify_(Id left, Id right, const Location& location); - public: /* * Convert monotype into a a polytype, by replacing any metavariables in descendant scopes diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index e72808da7..ca588ccb7 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -18,6 +18,8 @@ struct VariadicTypePack; struct TypePackVar; +struct TxnLog; + using TypePackId = const TypePackVar*; using FreeTypePack = Unifiable::Free; using BoundTypePack = Unifiable::Bound; @@ -84,6 +86,7 @@ struct TypePackIterator TypePackIterator() = default; explicit TypePackIterator(TypePackId tp); + TypePackIterator(TypePackId tp, const TxnLog* log); TypePackIterator& operator++(); TypePackIterator operator++(int); @@ -104,9 +107,13 @@ struct TypePackIterator TypePackId currentTypePack = nullptr; const TypePack* tp = nullptr; size_t currentIndex = 0; + + // Only used if LuauUseCommittingTxnLog is true. + const TxnLog* log; }; TypePackIterator begin(TypePackId tp); +TypePackIterator begin(TypePackId tp, TxnLog* log); TypePackIterator end(TypePackId tp); using SeenSet = std::set>; @@ -114,6 +121,7 @@ using SeenSet = std::set>; bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); +TypePackId follow(TypePackId tp, std::function mapper); size_t size(TypePackId tp); bool finite(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index f6829ec3e..d6e177142 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -453,6 +453,7 @@ bool areEqual(SeenSet& seen, const TypeVar& lhs, const TypeVar& rhs); // Follow BoundTypeVars until we get to something real TypeId follow(TypeId t); +TypeId follow(TypeId t, std::function mapper); std::vector flattenIntersection(TypeId ty); diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index 0ded14890..64227e7c1 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAG(LuauTypedAllocatorZeroStart) + namespace Luau { @@ -20,7 +22,10 @@ class TypedAllocator public: TypedAllocator() { - appendBlock(); + if (FFlag::LuauTypedAllocatorZeroStart) + currentBlockSize = kBlockSize; + else + appendBlock(); } ~TypedAllocator() @@ -59,12 +64,18 @@ class TypedAllocator bool empty() const { - return stuff.size() == 1 && currentBlockSize == 0; + if (FFlag::LuauTypedAllocatorZeroStart) + return stuff.empty(); + else + return stuff.size() == 1 && currentBlockSize == 0; } size_t size() const { - return kBlockSize * (stuff.size() - 1) + currentBlockSize; + if (FFlag::LuauTypedAllocatorZeroStart) + return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; + else + return kBlockSize * (stuff.size() - 1) + currentBlockSize; } void clear() @@ -72,7 +83,11 @@ class TypedAllocator if (frozen) unfreeze(); free(); - appendBlock(); + + if (FFlag::LuauTypedAllocatorZeroStart) + currentBlockSize = kBlockSize; + else + appendBlock(); } void freeze() diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 7681b9662..a3be739a6 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -25,6 +25,7 @@ struct Unifier Mode mode; ScopePtr globalScope; // sigh. Needed solely to get at string's metatable. + DEPRECATED_TxnLog DEPRECATED_log; TxnLog log; ErrorVec errors; Location location; @@ -33,44 +34,45 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState); + Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog = nullptr); Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState); + Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. - ErrorVec canUnify(TypeId superTy, TypeId subTy); - ErrorVec canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + ErrorVec canUnify(TypeId subTy, TypeId superTy); + ErrorVec canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); - /** Attempt to unify left with right. + /** Attempt to unify. * Populate the vector errors with any type errors that may arise. * Populate the transaction log with the set of TypeIds that need to be reset to undo the unification attempt. */ - void tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); private: - void tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall = false, bool isIntersection = false); - void tryUnifyPrimitives(TypeId superTy, TypeId subTy); - void tryUnifySingletons(TypeId superTy, TypeId subTy); - void tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall = false); - void tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); - void DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection = false); - void tryUnifyFreeTable(TypeId free, TypeId other); - void tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection); - void tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed); - void tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed); - void tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer); + void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnifyPrimitives(TypeId subTy, TypeId superTy); + void tryUnifySingletons(TypeId subTy, TypeId superTy); + void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); + void tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection = false); + void tryUnifyFreeTable(TypeId subTy, TypeId superTy); + void tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection); + void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); + void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); - void cacheResult(TypeId superTy, TypeId subTy); + void cacheResult(TypeId subTy, TypeId superTy); public: - void tryUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); + void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); private: - void tryUnify_(TypePackId superTy, TypePackId subTy, bool isFunctionCall = false); - void tryUnifyVariadics(TypePackId superTy, TypePackId subTy, bool reversed, int subOffset = 0); + void tryUnify_(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); + void tryUnifyVariadics(TypePackId subTy, TypePackId superTy, bool reversed, int subOffset = 0); - void tryUnifyWithAny(TypeId any, TypeId ty); - void tryUnifyWithAny(TypePackId any, TypePackId ty); + void tryUnifyWithAny(TypeId subTy, TypeId anyTy); + void tryUnifyWithAny(TypePackId subTy, TypePackId anyTp); std::optional findTablePropertyRespectingMeta(TypeId lhsType, Name name); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 4b583792c..67ebd0755 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -12,10 +12,12 @@ #include #include +LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); +LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -236,28 +238,31 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { ty = follow(ty); - auto canUnify = [&typeArena, &module](TypeId expectedType, TypeId actualType) { + auto canUnify = [&typeArena, &module](TypeId subTy, TypeId superTy) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); - if (FFlag::LuauAutocompleteAvoidMutation) + if (FFlag::LuauAutocompleteAvoidMutation && !FFlag::LuauUseCommittingTxnLog) { SeenTypes seenTypes; SeenTypePacks seenTypePacks; CloneState cloneState; - expectedType = clone(expectedType, *typeArena, seenTypes, seenTypePacks, cloneState); - actualType = clone(actualType, *typeArena, seenTypes, seenTypePacks, cloneState); + superTy = clone(superTy, *typeArena, seenTypes, seenTypePacks, cloneState); + subTy = clone(subTy, *typeArena, seenTypes, seenTypePacks, cloneState); - auto errors = unifier.canUnify(expectedType, actualType); + auto errors = unifier.canUnify(subTy, superTy); return errors.empty(); } else { - unifier.tryUnify(expectedType, actualType); + unifier.tryUnify(subTy, superTy); bool ok = unifier.errors.empty(); - unifier.log.rollback(); + + if (!FFlag::LuauUseCommittingTxnLog) + unifier.DEPRECATED_log.rollback(); + return ok; } }; @@ -293,22 +298,22 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty() && canUnify(expectedType, retHead.front())) + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return TypeCorrectKind::CorrectFunctionResult; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(expectedType, vtp->ty)) + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) return TypeCorrectKind::CorrectFunctionResult; } } - return canUnify(expectedType, ty) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } else { - if (canUnify(expectedType, ty)) + if (canUnify(ty, expectedType)) return TypeCorrectKind::Correct; // We also want to suggest functions that return compatible result @@ -320,13 +325,13 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ auto [retHead, retTail] = flatten(ftv->retType); if (!retHead.empty()) - return canUnify(expectedType, retHead.front()) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(expectedType, vtp->ty) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; } return TypeCorrectKind::None; @@ -1319,7 +1324,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is()) + if (!nodes.back()->is() && (!FFlag::LuauCompleteBrokenStringParams || !nodes.back()->is())) { return std::nullopt; } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index d0afa7424..249825067 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -138,12 +138,7 @@ declare function gcinfo(): number -- (nil, string). declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - -- a userdata object is "roughly" the same as a sealed empty table - -- except `type(newproxy(false))` evaluates to "userdata" so we may need another special type here too. - -- another important thing to note: the value passed in conditionally creates an empty metatable, and you have to use getmetatable, NOT - -- setmetatable. - -- FIXME: change this to something Luau can understand how to reject `setmetatable(newproxy(false or true), {})`. - declare function newproxy(mt: boolean?): {} + declare function newproxy(mt: boolean?): any declare coroutine: { create: ((A...) -> R...) -> thread, diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index e332f07d4..fe4b6529a 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -351,7 +351,7 @@ FrontendModuleResolver::FrontendModuleResolver(Frontend* frontend) { } -CheckResult Frontend::check(const ModuleName& name) +CheckResult Frontend::check(const ModuleName& name, std::optional optionOverride) { LUAU_TIMETRACE_SCOPE("Frontend::check", "Frontend"); LUAU_TIMETRACE_ARGUMENT("name", name.c_str()); @@ -372,6 +372,8 @@ CheckResult Frontend::check(const ModuleName& name) std::vector buildQueue; bool cycleDetected = parseGraph(buildQueue, checkResult, name); + FrontendOptions frontendOptions = optionOverride.value_or(options); + // Keep track of which AST nodes we've reported cycles in std::unordered_set reportedCycles; @@ -411,31 +413,11 @@ CheckResult Frontend::check(const ModuleName& name) // If we're typechecking twice, we do so. // The second typecheck is always in strict mode with DM awareness // to provide better typen information for IDE features. - if (options.typecheckTwice) + if (frontendOptions.typecheckTwice) { ModulePtr moduleForAutocomplete = typeCheckerForAutocomplete.check(sourceModule, Mode::Strict); moduleResolverForAutocomplete.modules[moduleName] = moduleForAutocomplete; } - else if (options.retainFullTypeGraphs && options.typecheckTwice && mode != Mode::Strict) - { - ModulePtr strictModule = typeChecker.check(sourceModule, Mode::Strict, environmentScope); - module->astTypes.clear(); - module->astOriginalCallTypes.clear(); - module->astExpectedTypes.clear(); - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - - for (const auto& [expr, strictTy] : strictModule->astTypes) - module->astTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (const auto& [expr, strictTy] : strictModule->astOriginalCallTypes) - module->astOriginalCallTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - - for (const auto& [expr, strictTy] : strictModule->astExpectedTypes) - module->astExpectedTypes[expr] = clone(strictTy, module->interfaceTypes, seenTypes, seenTypePacks, cloneState); - } stats.timeCheck += getTimestamp() - timestamp; stats.filesStrict += mode == Mode::Strict; @@ -444,7 +426,7 @@ CheckResult Frontend::check(const ModuleName& name) if (module == nullptr) throw std::runtime_error("Frontend::check produced a nullptr module for " + moduleName); - if (!options.retainFullTypeGraphs) + if (!frontendOptions.retainFullTypeGraphs) { // copyErrors needs to allocate into interfaceTypes as it copies // types out of internalTypes, so we unfreeze it here. diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index e1e53c971..cff85897c 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -13,7 +13,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) -LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 0) +LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) namespace Luau { diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index c773e208b..04ebffc1b 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -4,8 +4,6 @@ #include "Luau/VisitTypeVar.h" -LUAU_FASTFLAGVARIABLE(LuauQuantifyVisitOnce, false) - namespace Luau { @@ -81,16 +79,8 @@ struct Quantifier void quantify(ModulePtr module, TypeId ty, TypeLevel level) { Quantifier q{std::move(module), level}; - - if (FFlag::LuauQuantifyVisitOnce) - { - DenseHashSet seen{nullptr}; - visitTypeVarOnce(ty, q, seen); - } - else - { - visitTypeVar(ty, q); - } + DenseHashSet seen{nullptr}; + visitTypeVarOnce(ty, q, seen); FunctionTypeVar* ftv = getMutable(ty); LUAU_ASSERT(ftv); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index a6be53482..889dd6dc5 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) -LUAU_FASTFLAGVARIABLE(LuauFunctionArgumentNameSize, false) /* * Prefix generic typenames with gen- @@ -766,24 +765,12 @@ struct TypePackStringifier else state.emit(", "); - if (FFlag::LuauFunctionArgumentNameSize) + if (elemIndex < elemNames.size() && elemNames[elemIndex]) { - if (elemIndex < elemNames.size() && elemNames[elemIndex]) - { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); - } + state.emit(elemNames[elemIndex]->name); + state.emit(": "); } - else - { - LUAU_ASSERT(elemNames.empty() || elemIndex < elemNames.size()); - if (!elemNames.empty() && elemNames[elemIndex]) - { - state.emit(elemNames[elemIndex]->name); - state.emit(": "); - } - } elemIndex++; stringify(typeId); @@ -1151,38 +1138,19 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV s += ", "; first = false; - if (FFlag::LuauFunctionArgumentNameSize) + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) { - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) - { - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - ++argNameIter; - } - else - { - s += "_: "; - } + s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + ++argNameIter; } else { - // argNames is guaranteed to be equal to argTypes iff argNames is not empty. - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (!ftv.argNames.empty()) - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; + s += "_: "; } s += toString_(*argPackIter); ++argPackIter; - - if (!FFlag::LuauFunctionArgumentNameSize) - { - if (!ftv.argNames.empty()) - { - LUAU_ASSERT(argNameIter != ftv.argNames.end()); - ++argNameIter; - } - } } if (argPackIter.tail()) diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index f6a61581e..a46ac0c35 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -4,27 +4,34 @@ #include "Luau/TypePack.h" #include +#include + +LUAU_FASTFLAGVARIABLE(LuauUseCommittingTxnLog, false) namespace Luau { -void TxnLog::operator()(TypeId a) +void DEPRECATED_TxnLog::operator()(TypeId a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typeVarChanges.emplace_back(a, *a); } -void TxnLog::operator()(TypePackId a) +void DEPRECATED_TxnLog::operator()(TypePackId a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typePackChanges.emplace_back(a, *a); } -void TxnLog::operator()(TableTypeVar* a) +void DEPRECATED_TxnLog::operator()(TableTypeVar* a) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); tableChanges.emplace_back(a, a->boundTo); } -void TxnLog::rollback() +void DEPRECATED_TxnLog::rollback() { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) std::swap(*asMutable(it->first), it->second); @@ -38,8 +45,9 @@ void TxnLog::rollback() sharedSeen->resize(originalSeenSize); } -void TxnLog::concat(TxnLog rhs) +void DEPRECATED_TxnLog::concat(DEPRECATED_TxnLog rhs) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); rhs.typeVarChanges.clear(); @@ -50,23 +58,298 @@ void TxnLog::concat(TxnLog rhs) rhs.tableChanges.clear(); } -bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) +bool DEPRECATED_TxnLog::haveSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); } +void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + sharedSeen->push_back(sortedPair); +} + +void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + LUAU_ASSERT(sortedPair == sharedSeen->back()); + sharedSeen->pop_back(); +} + +static const TxnLog emptyLog; + +const TxnLog* TxnLog::empty() +{ + return &emptyLog; +} + +void TxnLog::concat(TxnLog rhs) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (auto& [ty, rep] : rhs.typeVarChanges) + typeVarChanges[ty] = std::move(rep); + + for (auto& [tp, rep] : rhs.typePackChanges) + typePackChanges[tp] = std::move(rep); +} + +void TxnLog::commit() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (auto& [ty, rep] : typeVarChanges) + *asMutable(ty) = rep.get()->pending; + + for (auto& [tp, rep] : typePackChanges) + *asMutable(tp) = rep.get()->pending; + + clear(); +} + +void TxnLog::clear() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + typeVarChanges.clear(); + typePackChanges.clear(); +} + +TxnLog TxnLog::inverse() +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + TxnLog inversed(sharedSeen); + + for (auto& [ty, _rep] : typeVarChanges) + inversed.typeVarChanges[ty] = std::make_unique(*ty); + + for (auto& [tp, _rep] : typePackChanges) + inversed.typePackChanges[tp] = std::make_unique(*tp); + + return inversed; +} + +bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) + { + return true; + } + + if (parent) + { + return parent->haveSeen(lhs, rhs); + } + + return false; +} + void TxnLog::pushSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeId lhs, TypeId rhs) { + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); } +PendingType* TxnLog::queue(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(!ty->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typeVarChanges[ty]; + if (!pending) + pending = std::make_unique(*ty); + + return pending.get(); +} + +PendingTypePack* TxnLog::queue(TypePackId tp) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(!tp->persistent); + + // Explicitly don't look in ancestors. If we have discovered something new + // about this type, we don't want to mutate the parent's state. + auto& pending = typePackChanges[tp]; + if (!pending) + pending = std::make_unique(*tp); + + return pending.get(); +} + +PendingType* TxnLog::pending(TypeId ty) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) + return it->second.get(); + } + + return nullptr; +} + +PendingTypePack* TxnLog::pending(TypePackId tp) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + for (const TxnLog* current = this; current; current = current->parent) + { + if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) + return it->second.get(); + } + + return nullptr; +} + +PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + PendingType* newTy = queue(ty); + newTy->pending = replacement; + return newTy; +} + +PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + PendingTypePack* newTp = queue(tp); + newTp->pending = replacement; + return newTp; +} + +PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + ttv->boundTo = newBoundTo; + + return newTy; +} + +PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty) || get(ty) || get(ty)); + + PendingType* newTy = queue(ty); + if (FreeTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + else if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + LUAU_ASSERT(ttv->state == TableState::Free || ttv->state == TableState::Generic); + ttv->level = newLevel; + } + else if (FunctionTypeVar* ftv = Luau::getMutable(newTy)) + { + ftv->level = newLevel; + } + + return newTy; +} + +PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(tp)); + + PendingTypePack* newTp = queue(tp); + if (FreeTypePack* ftp = Luau::getMutable(newTp)) + { + ftp->level = newLevel; + } + + return newTp; +} + +PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + LUAU_ASSERT(get(ty)); + + PendingType* newTy = queue(ty); + if (TableTypeVar* ttv = Luau::getMutable(newTy)) + { + ttv->indexer = indexer; + } + + return newTy; +} + +std::optional TxnLog::getLevel(TypeId ty) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + if (FreeTypeVar* ftv = getMutable(ty)) + return ftv->level; + else if (TableTypeVar* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) + return ttv->level; + else if (FunctionTypeVar* ftv = getMutable(ty)) + return ftv->level; + + return std::nullopt; +} + +TypeId TxnLog::follow(TypeId ty) +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + return Luau::follow(ty, [this](TypeId ty) { + PendingType* state = this->pending(ty); + + if (state == nullptr) + return ty; + + // Ugly: Fabricate a TypeId that doesn't adhere to most of the invariants + // that normally apply. This is safe because follow will only call get<> + // on the returned pointer. + return const_cast(&state->pending); + }); +} + +TypePackId TxnLog::follow(TypePackId tp) const +{ + LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); + + return Luau::follow(tp, [this](TypePackId tp) { + PendingTypePack* state = this->pending(tp); + + if (state == nullptr) + return tp; + + // Ugly: Fabricate a TypePackId that doesn't adhere to most of the + // invariants that normally apply. This is safe because follow will + // only call get<> on the returned pointer. + return const_cast(&state->pending); + }); +} + } // namespace Luau diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e29b6ec65..1689a5c3d 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -27,15 +27,16 @@ LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) +LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) +LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauTailArgumentTypeInfo, false) -LUAU_FASTFLAGVARIABLE(LuauModuleRequireErrorPack, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) @@ -450,7 +451,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) ++subLevel; TypeId leftType = checkFunctionName(scope, *fun->name, funScope->level); - unify(leftType, funTy, fun->location); + unify(funTy, leftType, fun->location); } else if (auto fun = (*protoIter)->as()) { @@ -556,21 +557,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatIf& statement) } } -ErrorVec TypeChecker::canUnify(TypeId left, TypeId right, const Location& location) +template +ErrorVec TypeChecker::canUnify_(Id subTy, Id superTy, const Location& location) { - return canUnify_(left, right, location); + Unifier state = mkUnifier(location); + return state.canUnify(subTy, superTy); } -ErrorVec TypeChecker::canUnify(TypePackId left, TypePackId right, const Location& location) +ErrorVec TypeChecker::canUnify(TypeId subTy, TypeId superTy, const Location& location) { - return canUnify_(left, right, location); + return canUnify_(subTy, superTy, location); } -template -ErrorVec TypeChecker::canUnify_(Id superTy, Id subTy, const Location& location) +ErrorVec TypeChecker::canUnify(TypePackId subTy, TypePackId superTy, const Location& location) { - Unifier state = mkUnifier(location); - return state.canUnify(superTy, subTy); + return canUnify_(subTy, superTy, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatWhile& statement) @@ -619,7 +620,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) // start typechecking everything across module boundaries. if (isNonstrictMode() && follow(scope->returnType) == follow(currentModule->getModuleScope()->returnType)) { - ErrorVec errors = tryUnify(scope->returnType, retPack, return_.location); + ErrorVec errors = tryUnify(retPack, scope->returnType, return_.location); if (!errors.empty()) currentModule->getModuleScope()->returnType = addTypePack({anyType}); @@ -627,29 +628,39 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatReturn& return_) return; } - unify(scope->returnType, retPack, return_.location, CountMismatch::Context::Return); + unify(retPack, scope->returnType, return_.location, CountMismatch::Context::Return); } -ErrorVec TypeChecker::tryUnify(TypeId left, TypeId right, const Location& location) +template +ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) { - return tryUnify_(left, right, location); + Unifier state = mkUnifier(location); + + if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + freeze(currentModule->internalTypes); + + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + unfreeze(currentModule->internalTypes); + + if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + + if (state.errors.empty() && FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + + return state.errors; } -ErrorVec TypeChecker::tryUnify(TypePackId left, TypePackId right, const Location& location) +ErrorVec TypeChecker::tryUnify(TypeId subTy, TypeId superTy, const Location& location) { - return tryUnify_(left, right, location); + return tryUnify_(subTy, superTy, location); } -template -ErrorVec TypeChecker::tryUnify_(Id left, Id right, const Location& location) +ErrorVec TypeChecker::tryUnify(TypePackId subTy, TypePackId superTy, const Location& location) { - Unifier state = mkUnifier(location); - state.tryUnify(left, right); - - if (!state.errors.empty()) - state.log.rollback(); - - return state.errors; + return tryUnify_(subTy, superTy, location); } void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) @@ -743,9 +754,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatAssign& assign) { // In nonstrict mode, any assignments where the lhs is free and rhs isn't a function, we give it any typevar. if (isNonstrictMode() && get(follow(left)) && !get(follow(right))) - unify(left, anyType, loc); + unify(anyType, left, loc); else - unify(left, right, loc); + unify(right, left, loc); } } } @@ -760,7 +771,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatCompoundAssign& assi TypeId result = checkBinaryOperation(scope, expr, left, right); - unify(left, result, assign.location); + unify(result, left, assign.location); } void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) @@ -817,9 +828,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) Unifier state = mkUnifier(local.location); state.ctx = CountMismatch::Result; - state.tryUnify(variablePack, valuePack); + state.tryUnify(valuePack, variablePack); reportErrors(state.errors); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + // In the code 'local T = {}', we wish to ascribe the name 'T' to the type of the table for error-reporting purposes. // We also want to do this for 'local T = setmetatable(...)'. if (local.vars.size == 1 && local.values.size == 1) @@ -889,7 +903,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) TypeId loopVarType = numberType; if (expr.var->annotation) - unify(resolveType(scope, *expr.var->annotation), loopVarType, expr.location); + unify(loopVarType, resolveType(scope, *expr.var->annotation), expr.location); loopScope->bindings[expr.var] = {loopVarType, expr.var->location}; @@ -899,11 +913,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatFor& expr) if (!expr.to) ice("Bad AstStatFor has no to expr"); - unify(loopVarType, checkExpr(loopScope, *expr.from).type, expr.from->location); - unify(loopVarType, checkExpr(loopScope, *expr.to).type, expr.to->location); + unify(checkExpr(loopScope, *expr.from).type, loopVarType, expr.from->location); + unify(checkExpr(loopScope, *expr.to).type, loopVarType, expr.to->location); if (expr.step) - unify(loopVarType, checkExpr(loopScope, *expr.step).type, expr.step->location); + unify(checkExpr(loopScope, *expr.step).type, loopVarType, expr.step->location); check(loopScope, *expr.body); } @@ -956,12 +970,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) if (get(callRetPack)) { iterTy = freshType(scope); - unify(addTypePack({{iterTy}, freshTypePack(scope)}), callRetPack, forin.location); + unify(callRetPack, addTypePack({{iterTy}, freshTypePack(scope)}), forin.location); } else if (get(callRetPack) || !first(callRetPack)) { for (TypeId var : varTypes) - unify(var, errorRecoveryType(scope), forin.location); + unify(errorRecoveryType(scope), var, forin.location); return check(loopScope, *forin.body); } @@ -982,7 +996,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) TypeId varTy = get(iterTy) ? anyType : errorRecoveryType(loopScope); for (TypeId var : varTypes) - unify(var, varTy, forin.location); + unify(varTy, var, forin.location); if (!get(iterTy) && !get(iterTy) && !get(iterTy)) reportError(TypeError{firstValue->location, TypeMismatch{globalScope->bindings[AstName{"next"}].typeId, iterTy}}); @@ -1010,6 +1024,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) Unifier state = mkUnifier(firstValue->location); checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + reportErrors(state.errors); } @@ -1024,10 +1041,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) AstExprCall exprCall{Location(start, end), firstValue, arguments, /* self= */ false, Location()}; TypePackId retPack = checkExprPack(scope, exprCall).type; - unify(varPack, retPack, forin.location); + unify(retPack, varPack, forin.location); } else - unify(varPack, iterFunc->retType, forin.location); + unify(iterFunc->retType, varPack, forin.location); check(loopScope, *forin.body); } @@ -1112,7 +1129,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - unify(leftType, ty, function.location); + unify(ty, leftType, function.location); if (FFlag::LuauUpdateFunctionNameBinding) { @@ -1242,7 +1259,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else if (auto mtv = getMutable(follow(ty))) mtv->syntheticName = name; - unify(bindingsMap[name].type, ty, typealias.location); + unify(ty, bindingsMap[name].type, typealias.location); } } @@ -1526,7 +1543,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa { TypeId head = freshType(scope); TypePackId pack = addTypePack(TypePackVar{TypePack{{head}, freshTypePack(scope)}}); - unify(retPack, pack, expr.location); + unify(pack, retPack, expr.location); return {head, std::move(result.predicates)}; } if (get(retPack)) @@ -1598,7 +1615,7 @@ std::optional TypeChecker::getIndexTypeFromType( return it->second.type; else if (auto indexer = tableType->indexer) { - tryUnify(indexer->indexType, stringType, location); + tryUnify(stringType, indexer->indexType, location); return indexer->indexResultType; } else if (tableType->state == TableState::Free) @@ -1824,7 +1841,7 @@ TypeId TypeChecker::checkExprTable( indexer = expectedTable->indexer; if (indexer) - unify(indexer->indexResultType, valueType, value->location); + unify(valueType, indexer->indexResultType, value->location); else indexer = TableIndexer{numberType, anyIfNonstrict(valueType)}; } @@ -1842,13 +1859,13 @@ TypeId TypeChecker::checkExprTable( if (it != expectedTable->props.end()) { Property expectedProp = it->second; - ErrorVec errors = tryUnify(expectedProp.type, exprType, k->location); + ErrorVec errors = tryUnify(exprType, expectedProp.type, k->location); if (errors.empty()) exprType = expectedProp.type; } else if (expectedTable->indexer && isString(expectedTable->indexer->indexType)) { - ErrorVec errors = tryUnify(expectedTable->indexer->indexResultType, exprType, k->location); + ErrorVec errors = tryUnify(exprType, expectedTable->indexer->indexResultType, k->location); if (errors.empty()) exprType = expectedTable->indexer->indexResultType; } @@ -1863,8 +1880,8 @@ TypeId TypeChecker::checkExprTable( if (indexer) { - unify(indexer->indexType, keyType, k->location); - unify(indexer->indexResultType, valueType, value->location); + unify(keyType, indexer->indexType, k->location); + unify(valueType, indexer->indexResultType, value->location); } else if (isNonstrictMode()) { @@ -1992,7 +2009,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) @@ -2006,7 +2026,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn return {errorRecoveryType(scope)}; } - reportErrors(tryUnify(numberType, operandType, expr.location)); + reportErrors(tryUnify(operandType, numberType, expr.location)); return {numberType}; } case AstExprUnary::Len: @@ -2072,7 +2092,7 @@ TypeId TypeChecker::unionOfTypes(TypeId a, TypeId b, const Location& location, b { if (unifyFreeTypes && (get(a) || get(b))) { - if (unify(a, b, location)) + if (unify(b, a, location)) return a; return errorRecoveryType(anyType); @@ -2175,7 +2195,13 @@ TypeId TypeChecker::checkRelationalOperation( */ Unifier state = mkUnifier(expr.location); if (!isEquality) - state.tryUnify(lhsType, rhsType); + { + state.tryUnify(rhsType, lhsType); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + } + bool needsMetamethod = !isEquality; @@ -2216,13 +2242,16 @@ TypeId TypeChecker::checkRelationalOperation( if (isEquality) { Unifier state = mkUnifier(expr.location); - state.tryUnify(ftv->retType, addTypePack({booleanType})); + state.tryUnify(addTypePack({booleanType}), ftv->retType); if (!state.errors.empty()) { reportError(expr.location, GenericError{format("Metamethod '%s' must return type 'boolean'", metamethodName.c_str())}); return errorRecoveryType(booleanType); } + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); } } @@ -2230,7 +2259,10 @@ TypeId TypeChecker::checkRelationalOperation( TypeId actualFunctionType = addType(FunctionTypeVar(scope->level, addTypePack({lhsType, rhsType}), addTypePack({booleanType}))); state.tryUnify( - instantiate(scope, *metamethod, expr.location), instantiate(scope, actualFunctionType, expr.location), /*isFunctionCall*/ true); + instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return booleanType; @@ -2323,7 +2355,7 @@ TypeId TypeChecker::checkBinaryOperation( } if (get(rhsType)) - unify(lhsType, rhsType, expr.location); + unify(rhsType, lhsType, expr.location); if (typeCouldHaveMetatable(lhsType) || typeCouldHaveMetatable(rhsType)) { @@ -2334,7 +2366,7 @@ TypeId TypeChecker::checkBinaryOperation( TypeId expectedFunctionType = addType(FunctionTypeVar(scope->level, arguments, retTypePack)); Unifier state = mkUnifier(expr.location); - state.tryUnify(expectedFunctionType, actualFunctionType, /*isFunctionCall*/ true); + state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); reportErrors(state.errors); bool hasErrors = !state.errors.empty(); @@ -2345,11 +2377,28 @@ TypeId TypeChecker::checkBinaryOperation( // so we loosen the argument types to see if that helps. TypePackId fallbackArguments = freshTypePack(scope); TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); - state.log.rollback(); state.errors.clear(); - state.tryUnify(fallbackFunctionType, actualFunctionType, /*isFunctionCall*/ true); - if (!state.errors.empty()) - state.log.rollback(); + + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.clear(); + } + else + { + state.DEPRECATED_log.rollback(); + } + + state.tryUnify(actualFunctionType, fallbackFunctionType, /*isFunctionCall*/ true); + + if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + state.log.commit(); + else if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + } + + if (FFlag::LuauUseCommittingTxnLog && !hasErrors) + { + state.log.commit(); } TypeId retType = first(retTypePack).value_or(nilType); @@ -2377,8 +2426,8 @@ TypeId TypeChecker::checkBinaryOperation( switch (expr.op) { case AstExprBinary::Concat: - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), lhsType, expr.left->location)); - reportErrors(tryUnify(addType(UnionTypeVar{{stringType, numberType}}), rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.left->location)); + reportErrors(tryUnify(rhsType, addType(UnionTypeVar{{stringType, numberType}}), expr.right->location)); return stringType; case AstExprBinary::Add: case AstExprBinary::Sub: @@ -2386,8 +2435,8 @@ TypeId TypeChecker::checkBinaryOperation( case AstExprBinary::Div: case AstExprBinary::Mod: case AstExprBinary::Pow: - reportErrors(tryUnify(numberType, lhsType, expr.left->location)); - reportErrors(tryUnify(numberType, rhsType, expr.right->location)); + reportErrors(tryUnify(lhsType, numberType, expr.left->location)); + reportErrors(tryUnify(rhsType, numberType, expr.right->location)); return numberType; default: // These should have been handled with checkRelationalOperation @@ -2466,10 +2515,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy if (FFlag::LuauBidirectionalAsExpr) { // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (canUnify(result.type, annotationType, expr.location).empty()) + if (canUnify(annotationType, result.type, expr.location).empty()) return {annotationType, std::move(result.predicates)}; - if (canUnify(annotationType, result.type, expr.location).empty()) + if (canUnify(result.type, annotationType, expr.location).empty()) return {annotationType, std::move(result.predicates)}; reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); @@ -2477,7 +2526,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy } else { - ErrorVec errorVec = canUnify(result.type, annotationType, expr.location); + ErrorVec errorVec = canUnify(annotationType, result.type, expr.location); reportErrors(errorVec); if (!errorVec.empty()) annotationType = errorRecoveryType(annotationType); @@ -2512,7 +2561,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf resolve(result.predicates, falseScope, false); ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); - unify(trueType.type, falseType.type, expr.location); + unify(falseType.type, trueType.type, expr.location); // TODO: normalize(UnionTypeVar{{trueType, falseType}}) // For now both trueType and falseType must be the same type. @@ -2607,14 +2656,18 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope else if (auto indexer = lhsTable->indexer) { Unifier state = mkUnifier(expr.location); - state.tryUnify(indexer->indexType, stringType); + state.tryUnify(stringType, indexer->indexType); TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + reportError(expr.location, UnknownProperty{lhs, name}); retType = errorRecoveryType(retType); } + else if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); return std::pair(retType, nullptr); } @@ -2713,7 +2766,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (exprTable->indexer) { const TableIndexer& indexer = *exprTable->indexer; - unify(indexer.indexType, indexType, expr.index->location); + unify(indexType, indexer.indexType, expr.index->location); return std::pair(indexer.indexResultType, nullptr); } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) @@ -3106,204 +3159,402 @@ void TypeChecker::checkArgumentList( * A function requires parameters. * To call a function, you supply arguments. */ - TypePackIterator argIter = begin(argPack); - TypePackIterator paramIter = begin(paramPack); + TypePackIterator argIter = begin(argPack, &state.log); + TypePackIterator paramIter = begin(paramPack, &state.log); TypePackIterator endIter = end(argPack); // Important subtlety: All end TypePackIterators are equivalent size_t paramIndex = 0; size_t minParams = getMinParameterCount(paramPack); - while (true) + if (FFlag::LuauUseCommittingTxnLog) { - state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - - if (argIter == endIter && paramIter == endIter) + while (true) { - std::optional argTail = argIter.tail(); - std::optional paramTail = paramIter.tail(); - - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - if (argTail) + if (argIter == endIter && paramIter == endIter) { - if (get(*argTail)) + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); + + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) { - if (paramTail) - state.tryUnify(*argTail, *paramTail); - else + if (state.log.getMutable(state.log.follow(*argTail))) { - state.log(*argTail); - *asMutable(*argTail) = TypePack{{}}; + if (paramTail) + state.tryUnify(*paramTail, *argTail); + else + state.log.replace(*argTail, TypePackVar(TypePack{{}})); } } - } - else if (paramTail) - { - // argTail is definitely empty - if (get(*paramTail)) + else if (paramTail) { - state.log(*paramTail); - *asMutable(*paramTail) = TypePack{{}}; + // argTail is definitely empty + if (state.log.getMutable(state.log.follow(*paramTail))) + state.log.replace(*paramTail, TypePackVar(TypePack{{}})); } + + return; } + else if (argIter == endIter) + { + // Not enough arguments. - return; - } - else if (argIter == endIter) - { - // Not enough arguments. + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) + { + TypePackId tail = *argIter.tail(); + if (state.log.getMutable(tail)) + { + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) + { + state.tryUnify(errorRecoveryType(anyType), *paramIter); + ++paramIter; + } + return; + } + else if (auto vtp = state.log.getMutable(tail)) + { + while (paramIter != endIter) + { + state.tryUnify(vtp->ty, *paramIter); + ++paramIter; + } + + return; + } + else if (state.log.getMutable(tail)) + { + std::vector rest; + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) + { + rest.push_back(*paramIter); + ++paramIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + state.tryUnify(varPack, tail); + return; + } + } - // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. - if (argIter.tail()) + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = state.log.follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (state.log.getMutable(t)) + { + } // ok + else if (isNonstrictMode() && state.log.getMutable(t)) + { + } // ok + else + { + state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } + ++paramIter; + } + } + else if (paramIter == endIter) { - TypePackId tail = *argIter.tail(); - if (get(tail)) + // too many parameters passed + if (!paramIter.tail()) { - // Unify remaining parameters so we don't leave any free-types hanging around. - while (paramIter != endIter) + while (argIter != endIter) { - state.tryUnify(*paramIter, errorRecoveryType(anyType)); - ++paramIter; + // The use of unify here is deliberate. We don't want this unification + // to be undoable. + unify(errorRecoveryType(scope), *argIter, state.location); + ++argIter; } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } - else if (auto vtp = get(tail)) + TypePackId tail = state.log.follow(*paramIter.tail()); + + if (state.log.getMutable(tail)) { - while (paramIter != endIter) + // Function is variadic. Ok. + return; + } + else if (auto vtp = state.log.getMutable(tail)) + { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) { - state.tryUnify(*paramIter, vtp->ty); - ++paramIter; + Location location = state.location; + + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; } return; } - else if (get(tail)) + else if (state.log.getMutable(tail)) { + // Create a type pack out of the remaining argument types + // and unify it with the tail. std::vector rest; - rest.reserve(std::distance(paramIter, endIter)); - while (paramIter != endIter) + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) { - rest.push_back(*paramIter); - ++paramIter; + rest.push_back(*argIter); + ++argIter; } - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(tail, varPack); + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(varPack, tail); return; } - } - - // If any remaining unfulfilled parameters are nonoptional, this is a problem. - while (paramIter != endIter) - { - TypeId t = follow(*paramIter); - if (isOptional(t)) + else if (state.log.getMutable(tail)) { - } // ok - else if (get(t)) - { - } // ok - else if (isNonstrictMode() && get(t)) - { - } // ok - else + state.log.replace(tail, TypePackVar(TypePack{{}})); + return; + } + else if (state.log.getMutable(tail)) { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } + } + else + { + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; ++paramIter; } + + ++paramIndex; } - else if (paramIter == endIter) + } + else + { + while (true) { - // too many parameters passed - if (!paramIter.tail()) + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; + + if (argIter == endIter && paramIter == endIter) { - while (argIter != endIter) + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); + + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. + + if (argTail) { - unify(*argIter, errorRecoveryType(scope), state.location); - ++argIter; + if (get(*argTail)) + { + if (paramTail) + state.tryUnify(*paramTail, *argTail); + else + { + state.DEPRECATED_log(*argTail); + *asMutable(*argTail) = TypePack{{}}; + } + } + } + else if (paramTail) + { + // argTail is definitely empty + if (get(*paramTail)) + { + state.DEPRECATED_log(*paramTail); + *asMutable(*paramTail) = TypePack{{}}; + } } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - TypePackId tail = *paramIter.tail(); - if (get(tail)) - { - // Function is variadic. Ok. return; } - else if (auto vtp = get(tail)) + else if (argIter == endIter) { - // Function is variadic and requires that all subsequent parameters - // be compatible with a type. - size_t argIndex = paramIndex; - while (argIter != endIter) + // Not enough arguments. + + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) { - Location location = state.location; + TypePackId tail = *argIter.tail(); + if (get(tail)) + { + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) + { + state.tryUnify(*paramIter, errorRecoveryType(anyType)); + ++paramIter; + } + return; + } + else if (auto vtp = get(tail)) + { + while (paramIter != endIter) + { + state.tryUnify(*paramIter, vtp->ty); + ++paramIter; + } - if (argIndex < argLocations.size()) - location = argLocations[argIndex]; + return; + } + else if (get(tail)) + { + std::vector rest; + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) + { + rest.push_back(*paramIter); + ++paramIter; + } - unify(vtp->ty, *argIter, location); - ++argIter; - ++argIndex; + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); + state.tryUnify(varPack, tail); + return; + } } - return; + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = follow(*paramIter); + if (isOptional(t)) + { + } // ok + else if (get(t)) + { + } // ok + else if (isNonstrictMode() && get(t)) + { + } // ok + else + { + state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + return; + } + ++paramIter; + } } - else if (get(tail)) + else if (paramIter == endIter) { - // Create a type pack out of the remaining argument types - // and unify it with the tail. - std::vector rest; - rest.reserve(std::distance(argIter, endIter)); - while (argIter != endIter) + // too many parameters passed + if (!paramIter.tail()) { - rest.push_back(*argIter); - ++argIter; + while (argIter != endIter) + { + unify(*argIter, errorRecoveryType(scope), state.location); + ++argIter; + } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; } + TypePackId tail = *paramIter.tail(); - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(tail, varPack); - return; - } - else if (get(tail)) - { - state.log(tail); - *asMutable(tail) = TypePack{}; + if (get(tail)) + { + // Function is variadic. Ok. + return; + } + else if (auto vtp = get(tail)) + { + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) + { + Location location = state.location; - return; + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; + + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; + } + + return; + } + else if (get(tail)) + { + // Create a type pack out of the remaining argument types + // and unify it with the tail. + std::vector rest; + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) + { + rest.push_back(*argIter); + ++argIter; + } + + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(tail, varPack); + return; + } + else if (get(tail)) + { + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.replace(tail, TypePackVar(TypePack{{}})); + } + else + { + state.DEPRECATED_log(tail); + *asMutable(tail) = TypePack{}; + } + + return; + } + else if (get(tail)) + { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } } - else if (get(tail)) + else { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; + ++paramIter; } - } - else - { - unifyWithInstantiationIfNeeded(scope, *paramIter, *argIter, state); - ++argIter; - ++paramIter; - } - ++paramIndex; + ++paramIndex; + } } } @@ -3475,7 +3726,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (get(fn)) { - unify(argPack, anyTypePack, expr.location); + unify(anyTypePack, argPack, expr.location); return {{anyTypePack}}; } @@ -3490,7 +3741,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - unify(r, fn, expr.location); + unify(fn, r, expr.location); return {{retPack}}; } @@ -3533,7 +3784,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope if (!ftv) { reportError(TypeError{expr.func->location, CannotCallNonFunction{fn}}); - unify(retPack, errorRecoveryTypePack(scope), expr.func->location); + unify(errorRecoveryTypePack(scope), retPack, expr.func->location); return {{errorRecoveryTypePack(retPack)}}; } @@ -3552,7 +3803,9 @@ std::optional> TypeChecker::checkCallOverload(const Scope checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); if (!state.errors.empty()) { - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); + return {}; } @@ -3580,10 +3833,15 @@ std::optional> TypeChecker::checkCallOverload(const Scope overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); - state.log.rollback(); + + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); } else { + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) { // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND @@ -3640,6 +3898,9 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { + if (FFlag::LuauUseCommittingTxnLog) + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionDoesNotTakeSelf{}}); // This is a little bit suspect: If this overload would work with a . replaced by a : // we eagerly assume that that's what you actually meant and we commit to it. @@ -3648,8 +3909,8 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); + else if (!FFlag::LuauUseCommittingTxnLog) + editedState.DEPRECATED_log.rollback(); } else if (ftv->hasSelf) { @@ -3671,6 +3932,9 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { + if (FFlag::LuauUseCommittingTxnLog) + editedState.log.commit(); + reportError(TypeError{expr.location, FunctionRequiresSelf{}}); // This is a little bit suspect: If this overload would work with a : replaced by a . // we eagerly assume that that's what you actually meant and we commit to it. @@ -3679,8 +3943,8 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else - editedState.log.rollback(); + else if (!FFlag::LuauUseCommittingTxnLog) + editedState.DEPRECATED_log.rollback(); } } } @@ -3740,6 +4004,9 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } + if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + state.log.commit(); + if (i > 0) s += "; "; @@ -3748,7 +4015,8 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast s += toString(overload); - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); } if (overloadsThatMatchArgCount.size() == 0) @@ -3781,6 +4049,8 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L Unifier state = mkUnifier(location); + std::vector inverseLogs; + for (size_t i = 0; i < exprs.size; ++i) { AstExpr* expr = exprs.data[i]; @@ -3791,18 +4061,15 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L auto [typePack, exprPredicates] = checkExprPack(scope, *expr); insert(exprPredicates); - if (FFlag::LuauTailArgumentTypeInfo) + if (std::optional firstTy = first(typePack)) { - if (std::optional firstTy = first(typePack)) - { - if (!currentModule->astTypes.find(expr)) - currentModule->astTypes[expr] = follow(*firstTy); - } - - if (expectedType) - currentModule->astExpectedTypes[expr] = *expectedType; + if (!currentModule->astTypes.find(expr)) + currentModule->astTypes[expr] = follow(*firstTy); } + if (expectedType) + currentModule->astExpectedTypes[expr] = *expectedType; + tp->tail = typePack; } else @@ -3816,13 +4083,31 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L actualType = instantiate(scope, actualType, expr->location); if (expectedType) - state.tryUnify(*expectedType, actualType); + { + state.tryUnify(actualType, *expectedType); + + // Ugly: In future iterations of the loop, we might need the state of the unification we + // just performed. There's not a great way to pass that into checkExpr. Instead, we store + // the inverse of the current log, and commit it. When we're done, we'll commit all the + // inverses. This isn't optimal, and a better solution is welcome here. + if (FFlag::LuauUseCommittingTxnLog) + { + inverseLogs.push_back(state.log.inverse()); + state.log.commit(); + } + } tp->head.push_back(actualType); } } - state.log.rollback(); + if (FFlag::LuauUseCommittingTxnLog) + { + for (TxnLog& log : inverseLogs) + log.commit(); + } + else + state.DEPRECATED_log.rollback(); return {pack, predicates}; } @@ -3884,7 +4169,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module TypePackId modulePack = module->getModuleScope()->returnType; - if (FFlag::LuauModuleRequireErrorPack && get(modulePack)) + if (get(modulePack)) return errorRecoveryType(scope); std::optional moduleType = first(modulePack); @@ -3917,72 +4202,94 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const return ty; } -bool TypeChecker::unify(TypeId left, TypeId right, const Location& location) +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location) { Unifier state = mkUnifier(location); - state.tryUnify(left, right); + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unify(TypePackId left, TypePackId right, const Location& location, CountMismatch::Context ctx) +bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx) { Unifier state = mkUnifier(location); state.ctx = ctx; - state.tryUnify(left, right); + state.tryUnify(subTy, superTy); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, const Location& location) +bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, const Location& location) { Unifier state = mkUnifier(location); - unifyWithInstantiationIfNeeded(scope, left, right, state); + unifyWithInstantiationIfNeeded(scope, subTy, superTy, state); + + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); reportErrors(state.errors); return state.errors.empty(); } -void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId left, TypeId right, Unifier& state) +void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId subTy, TypeId superTy, Unifier& state) { - if (!maybeGeneric(right)) + if (!maybeGeneric(subTy)) // Quick check to see if we definitely can't instantiate - state.tryUnify(left, right, /*isFunctionCall*/ false); - else if (!maybeGeneric(left) && isGeneric(right)) + state.tryUnify(subTy, superTy, /*isFunctionCall*/ false); + else if (!maybeGeneric(superTy) && isGeneric(subTy)) { // Quick check to see if we definitely have to instantiate - TypeId instantiated = instantiate(scope, right, state.location); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + TypeId instantiated = instantiate(scope, subTy, state.location); + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } else { // First try unifying with the original uninstantiated type // but if that fails, try the instantiated one. Unifier child = state.makeChildUnifier(); - child.tryUnify(left, right, /*isFunctionCall*/ false); + child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); if (!child.errors.empty()) { - TypeId instantiated = instantiate(scope, right, state.location); - if (right == instantiated) + TypeId instantiated = instantiate(scope, subTy, state.location); + if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - state.log.concat(std::move(child.log)); + if (FFlag::LuauUseCommittingTxnLog) + state.log.concat(std::move(child.log)); + else + state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); } else { - child.log.rollback(); - state.tryUnify(left, instantiated, /*isFunctionCall*/ false); + if (!FFlag::LuauUseCommittingTxnLog) + child.DEPRECATED_log.rollback(); + + state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } } else { - state.log.concat(std::move(child.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + state.log.concat(std::move(child.log)); + } + else + { + state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + } } } } @@ -4139,7 +4446,7 @@ TypePackId Quantification::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { if (const TableTypeVar* ttv = get(ty)) - return (ttv->state == TableState::Free); + return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); else if (get(ty)) return true; else @@ -4162,6 +4469,12 @@ TypeId Anyification::clean(TypeId ty) TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; + if (FFlag::LuauSealExports) + { + clone.name = ttv->name; + clone.syntheticName = ttv->syntheticName; + clone.tags = ttv->tags; + } return addType(std::move(clone)); } else @@ -5194,8 +5507,8 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement // This by itself is not truly enough to determine that A is stronger than B or vice versa. // The best unambiguous way about this would be to have a function that returns the relationship ordering of a pair. // i.e. TypeRelationship relationshipOf(TypeId superTy, TypeId subTy) - bool optionIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); - bool targetIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + bool optionIsSubtype = canUnify(option, isaP.ty, isaP.location).empty(); + bool targetIsSubtype = canUnify(isaP.ty, option, isaP.location).empty(); // If A is a superset of B, then if sense is true, we promote A to B, otherwise we keep A. if (!optionIsSubtype && targetIsSubtype) @@ -5379,7 +5692,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa for (TypeId right : rhs) { // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(left, right, eqP.location).empty() == sense || (!sense && !isNil(left))) + if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) set.insert(left); } } @@ -5406,7 +5719,7 @@ std::vector TypeChecker::unTypePack(const ScopePtr& scope, TypePackId tp for (size_t i = 0; i < expectedLength; ++i) expectedPack->head.push_back(freshType(scope)); - unify(expectedTypePack, tp, location); + unify(tp, expectedTypePack, location); for (TypeId& tp : expectedPack->head) tp = follow(tp); diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index d3221c732..b15548a8d 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -1,8 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypePack.h" +#include "Luau/TxnLog.h" + #include +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + namespace Luau { @@ -35,14 +39,28 @@ TypePackVar& TypePackVar::operator=(TypePackVariant&& tp) } TypePackIterator::TypePackIterator(TypePackId typePack) + : TypePackIterator(typePack, TxnLog::empty()) +{ +} + +TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) : currentTypePack(follow(typePack)) , tp(get(currentTypePack)) , currentIndex(0) + , log(log) { while (tp && tp->head.empty()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + if (FFlag::LuauUseCommittingTxnLog) + { + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + } + else + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + } } } @@ -53,8 +71,17 @@ TypePackIterator& TypePackIterator::operator++() ++currentIndex; while (tp && currentIndex >= tp->head.size()) { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; + if (FFlag::LuauUseCommittingTxnLog) + { + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; + } + else + { + currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; + tp = currentTypePack ? get(currentTypePack) : nullptr; + } + currentIndex = 0; } @@ -95,6 +122,11 @@ TypePackIterator begin(TypePackId tp) return TypePackIterator{tp}; } +TypePackIterator begin(TypePackId tp, TxnLog* log) +{ + return TypePackIterator{tp, log}; +} + TypePackIterator end(TypePackId tp) { return TypePackIterator{}; @@ -160,8 +192,15 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs) TypePackId follow(TypePackId tp) { - auto advance = [](TypePackId ty) -> std::optional { - if (const Unifiable::Bound* btv = get>(ty)) + return follow(tp, [](TypePackId t) { + return t; + }); +} + +TypePackId follow(TypePackId tp, std::function mapper) +{ + auto advance = [&mapper](TypePackId ty) -> std::optional { + if (const Unifiable::Bound* btv = get>(mapper(ty))) return btv->boundTo; else return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index fb75aa02e..4cab79c8a 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -31,17 +31,24 @@ std::optional> magicFunctionFormat( TypeId follow(TypeId t) { - auto advance = [](TypeId ty) -> std::optional { - if (auto btv = get>(ty)) + return follow(t, [](TypeId t) { + return t; + }); +} + +TypeId follow(TypeId t, std::function mapper) +{ + auto advance = [&mapper](TypeId ty) -> std::optional { + if (auto btv = get>(mapper(ty))) return btv->boundTo; - else if (auto ttv = get(ty)) + else if (auto ttv = get(mapper(ty))) return ttv->boundTo; else return std::nullopt; }; - auto force = [](TypeId ty) { - if (auto ltv = get_if(&ty->ty)) + auto force = [&mapper](TypeId ty) { + if (auto ltv = get_if(&mapper(ty)->ty)) { TypeId res = ltv->thunk(); if (get(res)) @@ -1004,7 +1011,7 @@ std::optional> magicFunctionFormat( { Location location = expr.args.data[std::min(i + dataOffset, expr.args.size - 1)]->location; - typechecker.unify(expected[i], params[i + paramOffset], location); + typechecker.unify(params[i + paramOffset], expected[i], location); } // if we know the argument count or if we have too many arguments for sure, we can issue an error diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index f037351e5..1f7ef8c25 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -20,6 +20,7 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE); #include LUAU_FASTFLAG(DebugLuauFreezeArena) +LUAU_FASTFLAGVARIABLE(LuauTypedAllocatorZeroStart, false) namespace Luau { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 43ea37e7b..393a84a70 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -13,6 +13,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); +LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) @@ -29,27 +30,39 @@ namespace Luau struct PromoteTypeLevels { + DEPRECATED_TxnLog& DEPRECATED_log; TxnLog& log; TypeLevel minLevel; - explicit PromoteTypeLevels(TxnLog& log, TypeLevel minLevel) - : log(log) + explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel) + : DEPRECATED_log(DEPRECATED_log) + , log(log) , minLevel(minLevel) - {} + { + } - template + template void promote(TID ty, T* t) { LUAU_ASSERT(t); if (minLevel.subsumesStrict(t->level)) { - log(ty); - t->level = minLevel; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeLevel(ty, minLevel); + } + else + { + DEPRECATED_log(ty); + t->level = minLevel; + } } } template - void cycle(TID) {} + void cycle(TID) + { + } template bool operator()(TID, const T&) @@ -59,39 +72,47 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const FreeTypeVar&) { - promote(ty, getMutable(ty)); + // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauUseCommittingTxnLog && !log.is(ty)) + return true; + + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypeId ty, const FunctionTypeVar&) { - promote(ty, getMutable(ty)); + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } - bool operator()(TypeId ty, const TableTypeVar&) + bool operator()(TypeId ty, const TableTypeVar& ttv) { - promote(ty, getMutable(ty)); + if (ttv.state != TableState::Free && ttv.state != TableState::Generic) + return true; + + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypePackId tp, const FreeTypePack&) { - promote(tp, getMutable(tp)); + promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); return true; } }; -void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypeId ty) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypeId ty) { - PromoteTypeLevels ptl{log, minLevel}; + PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, ptl, seen); } -void promoteTypeLevels(TxnLog& log, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypePackId tp) { - PromoteTypeLevels ptl{log, minLevel}; + PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(tp, ptl, seen); } @@ -221,10 +242,12 @@ static std::optional> getTableMat return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState) +Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + TxnLog* parentLog) : types(types) , mode(mode) , globalScope(std::move(globalScope)) + , log(parentLog) , location(location) , variance(variance) , sharedState(sharedState) @@ -233,11 +256,12 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati } Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, - Variance variance, UnifierSharedState& sharedState) + Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) , globalScope(std::move(globalScope)) - , log(sharedSeen) + , DEPRECATED_log(sharedSeen) + , log(parentLog, sharedSeen) , location(location) , variance(variance) , sharedState(sharedState) @@ -245,14 +269,14 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector< LUAU_ASSERT(sharedState.iceHandler); } -void Unifier::tryUnify(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { sharedState.counters.iterationCount = 0; - tryUnify_(superTy, subTy, isFunctionCall, isIntersection); + tryUnify_(subTy, superTy, isFunctionCall, isIntersection); } -void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool isIntersection) +void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool isIntersection) { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -264,55 +288,112 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool return; } - superTy = follow(superTy); - subTy = follow(subTy); + if (FFlag::LuauUseCommittingTxnLog) + { + superTy = log.follow(superTy); + subTy = log.follow(subTy); + } + else + { + superTy = follow(superTy); + subTy = follow(subTy); + } if (superTy == subTy) return; - auto l = getMutable(superTy); - auto r = getMutable(subTy); + auto superFree = getMutable(superTy); + auto subFree = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superFree = log.getMutable(superTy); + subFree = log.getMutable(subTy); + } - if (l && r && l->level.subsumes(r->level)) + if (superFree && subFree && superFree->level.subsumes(subFree->level)) { occursCheck(subTy, superTy); // The occurrence check might have caused superTy no longer to be a free type - if (!get(subTy)) + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(subTy)); + else + occursFailed = bool(get(subTy)); + + if (!occursFailed) { - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); + if (FFlag::LuauUseCommittingTxnLog) + { + log.replace(subTy, BoundTypeVar(superTy)); + } + else + { + DEPRECATED_log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } } return; } - else if (l && r) + else if (superFree && subFree) { - if (!FFlag::LuauErrorRecoveryType) - log(superTy); + if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) + { + DEPRECATED_log(superTy); + subFree->level = min(subFree->level, superFree->level); + } + occursCheck(superTy, subTy); - r->level = min(r->level, l->level); - // The occurrence check might have caused superTy no longer to be a free type - if (!FFlag::LuauErrorRecoveryType) - *asMutable(superTy) = BoundTypeVar(subTy); - else if (!get(superTy)) + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(superTy)); + else + occursFailed = bool(get(superTy)); + + if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) { - log(superTy); *asMutable(superTy) = BoundTypeVar(subTy); + return; + } + + if (!occursFailed) + { + if (FFlag::LuauUseCommittingTxnLog) + { + if (superFree->level.subsumes(subFree->level)) + { + log.changeLevel(subTy, superFree->level); + } + + log.replace(superTy, BoundTypeVar(subTy)); + } + else + { + DEPRECATED_log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + subFree->level = min(subFree->level, superFree->level); + } } return; } - else if (l) + else if (superFree) { occursCheck(superTy, subTy); + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(superTy)); + else + occursFailed = bool(get(superTy)); - TypeLevel superLevel = l->level; + TypeLevel superLevel = superFree->level; // Unification can't change the level of a generic. - auto rightGeneric = get(subTy); - if (rightGeneric && !rightGeneric->level.subsumes(superLevel)) + auto subGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy); + if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); @@ -320,63 +401,83 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool } // The occurrence check might have caused superTy no longer to be a free type - if (!get(superTy)) + if (!occursFailed) { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(log, superLevel, subTy); - else if (auto rightLevel = getMutableLevel(subTy)) + if (FFlag::LuauUseCommittingTxnLog) { - if (!rightLevel->subsumes(l->level)) - *rightLevel = l->level; + promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + log.replace(superTy, BoundTypeVar(subTy)); } + else + { + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + else if (auto subLevel = getMutableLevel(subTy)) + { + if (!subLevel->subsumes(superFree->level)) + *subLevel = superFree->level; + } - log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); + DEPRECATED_log(superTy); + *asMutable(superTy) = BoundTypeVar(subTy); + } } return; } - else if (r) + else if (subFree) { - TypeLevel subLevel = r->level; + TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); + bool occursFailed = false; + if (FFlag::LuauUseCommittingTxnLog) + occursFailed = bool(log.getMutable(subTy)); + else + occursFailed = bool(get(subTy)); // Unification can't change the level of a generic. - auto leftGeneric = get(superTy); - if (leftGeneric && !leftGeneric->level.subsumes(r->level)) + auto superGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy); + if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } - if (!get(subTy)) + if (!occursFailed) { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(log, subLevel, superTy); - - if (auto superLevel = getMutableLevel(superTy)) + if (FFlag::LuauUseCommittingTxnLog) { - if (!superLevel->subsumes(r->level)) + promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + log.replace(subTy, BoundTypeVar(superTy)); + } + else + { + if (FFlag::LuauProperTypeLevels) + promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + else if (auto superLevel = getMutableLevel(superTy)) { - log(superTy); - *superLevel = r->level; + if (!superLevel->subsumes(subFree->level)) + { + DEPRECATED_log(superTy); + *superLevel = subFree->level; + } } - } - log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); + DEPRECATED_log(subTy); + *asMutable(subTy) = BoundTypeVar(superTy); + } } return; } if (get(superTy) || get(superTy)) - return tryUnifyWithAny(superTy, subTy); + return tryUnifyWithAny(subTy, superTy); if (get(subTy) || get(subTy)) - return tryUnifyWithAny(subTy, superTy); + return tryUnifyWithAny(superTy, subTy); bool cacheEnabled = !isFunctionCall && !isIntersection; auto& cache = sharedState.cachedUnify; @@ -389,12 +490,22 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. - if (log.haveSeen(superTy, subTy)) - return; + if (FFlag::LuauUseCommittingTxnLog) + { + if (log.haveSeen(superTy, subTy)) + return; + + log.pushSeen(superTy, subTy); + } + else + { + if (DEPRECATED_log.haveSeen(superTy, subTy)) + return; - log.pushSeen(superTy, subTy); + DEPRECATED_log.pushSeen(superTy, subTy); + } - if (const UnionTypeVar* uv = get(subTy)) + if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { // A | B <: T if A <: T and B <: T bool failed = false; @@ -407,7 +518,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool for (TypeId type : uv->options) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type); + innerState.tryUnify_(type, superTy); if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; @@ -420,10 +531,24 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool failed = true; } - if (i != count - 1) - innerState.log.rollback(); + if (FFlag::LuauUseCommittingTxnLog) + { + if (i == count - 1) + { + log.concat(std::move(innerState.log)); + } + } else - log.concat(std::move(innerState.log)); + { + if (i != count - 1) + { + innerState.DEPRECATED_log.rollback(); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + } ++i; } @@ -438,7 +563,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } } - else if (const UnionTypeVar* uv = get(superTy)) + else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { // T <: A | B if T <: A or T <: B bool found = false; @@ -502,12 +627,16 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { TypeId type = uv->options[(i + startIndex) % uv->options.size()]; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, isFunctionCall); + innerState.tryUnify_(subTy, type, isFunctionCall); if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) @@ -522,7 +651,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool failedOption = {innerState.errors.front()}; } - innerState.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -538,7 +668,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } - else if (const IntersectionTypeVar* uv = get(superTy)) + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) { std::optional unificationTooComplex; std::optional firstFailedOption; @@ -547,7 +678,7 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool for (TypeId type : uv->parts) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, subTy, /*isFunctionCall*/ false, /*isIntersection*/ true); + innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); if (auto e = hasUnificationTooComplex(innerState.errors)) unificationTooComplex = e; @@ -557,7 +688,10 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool firstFailedOption = {innerState.errors.front()}; } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } if (unificationTooComplex) @@ -565,7 +699,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool else if (firstFailedOption) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } - else if (const IntersectionTypeVar* uv = get(subTy)) + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { // A & B <: T if T <: A or T <: B bool found = false; @@ -591,12 +726,15 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool { TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(superTy, type, isFunctionCall); + innerState.tryUnify_(type, superTy, isFunctionCall); if (innerState.errors.empty()) { found = true; - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) @@ -604,7 +742,8 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool unificationTooComplex = e; } - innerState.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -614,44 +753,56 @@ void Unifier::tryUnify_(TypeId superTy, TypeId subTy, bool isFunctionCall, bool errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } - else if (get(superTy) && get(subTy)) - tryUnifyPrimitives(superTy, subTy); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyPrimitives(subTy, superTy); - else if (FFlag::LuauSingletonTypes && (get(superTy) || get(superTy)) && get(subTy)) - tryUnifySingletons(superTy, subTy); + else if (FFlag::LuauSingletonTypes && + ((FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) || + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) && + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy))) + tryUnifySingletons(subTy, superTy); - else if (get(superTy) && get(subTy)) - tryUnifyFunctions(superTy, subTy, isFunctionCall); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyFunctions(subTy, superTy, isFunctionCall); - else if (get(superTy) && get(subTy)) + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) { - tryUnifyTables(superTy, subTy, isIntersection); + tryUnifyTables(subTy, superTy, isIntersection); if (cacheEnabled && errors.empty()) - cacheResult(superTy, subTy); + cacheResult(subTy, superTy); } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if (get(superTy)) - tryUnifyWithMetatable(superTy, subTy, /*reversed*/ false); - else if (get(subTy)) - tryUnifyWithMetatable(subTy, superTy, /*reversed*/ true); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - else if (get(superTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ false); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ false); // Unification of nonclasses with classes is almost, but not quite symmetrical. // The order in which we perform this test is significant in the case that both types are classes. - else if (get(subTy)) - tryUnifyWithClass(superTy, subTy, /*reversed*/ true); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); - log.popSeen(superTy, subTy); + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superTy, subTy); + else + DEPRECATED_log.popSeen(superTy, subTy); } -void Unifier::cacheResult(TypeId superTy, TypeId subTy) +void Unifier::cacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); @@ -684,7 +835,7 @@ void Unifier::cacheResult(TypeId superTy, TypeId subTy) sharedState.cachedUnify.insert({subTy, superTy}); } -struct WeirdIter +struct DEPRECATED_WeirdIter { TypePackId packId; const TypePack* pack; @@ -692,7 +843,7 @@ struct WeirdIter bool growing; TypeLevel level; - WeirdIter(TypePackId packId) + DEPRECATED_WeirdIter(TypePackId packId) : packId(packId) , pack(get(packId)) , index(0) @@ -705,7 +856,7 @@ struct WeirdIter } } - WeirdIter(const WeirdIter&) = default; + DEPRECATED_WeirdIter(const DEPRECATED_WeirdIter&) = default; const TypeId& operator*() { @@ -756,34 +907,152 @@ struct WeirdIter } }; -ErrorVec Unifier::canUnify(TypeId superTy, TypeId subTy) +struct WeirdIter +{ + TypePackId packId; + TxnLog& log; + TypePack* pack; + size_t index; + bool growing; + TypeLevel level; + + WeirdIter(TypePackId packId, TxnLog& log) + : packId(packId) + , log(log) + , pack(log.getMutable(packId)) + , index(0) + , growing(false) + { + while (pack && pack->head.empty() && pack->tail) + { + packId = *pack->tail; + pack = log.getMutable(packId); + } + } + + WeirdIter(const WeirdIter&) = default; + + TypeId& operator*() + { + LUAU_ASSERT(good()); + return pack->head[index]; + } + + bool good() const + { + return pack != nullptr && index < pack->head.size(); + } + + bool advance() + { + if (!pack) + return good(); + + if (index < pack->head.size()) + ++index; + + if (growing || index < pack->head.size()) + return good(); + + if (pack->tail) + { + packId = log.follow(*pack->tail); + pack = log.getMutable(packId); + index = 0; + } + + return good(); + } + + bool canGrow() const + { + return nullptr != log.getMutable(packId); + } + + void grow(TypePackId newTail) + { + LUAU_ASSERT(canGrow()); + LUAU_ASSERT(log.getMutable(newTail)); + + level = log.getMutable(packId)->level; + log.replace(packId, Unifiable::Bound(newTail)); + packId = newTail; + pack = log.getMutable(newTail); + index = 0; + growing = true; + } + + void pushType(TypeId ty) + { + LUAU_ASSERT(pack); + PendingTypePack* pendingPack = log.queue(packId); + if (TypePack* pending = getMutable(pendingPack)) + { + pending->head.push_back(ty); + // We've potentially just replaced the TypePack* that we need to look + // in. We need to replace pack. + pack = pending; + } + else + { + LUAU_ASSERT(!"Pending state for this pack was not a TypePack"); + } + } +}; + +ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy); - s.log.rollback(); + s.tryUnify_(subTy, superTy); + + if (!FFlag::LuauUseCommittingTxnLog) + s.DEPRECATED_log.rollback(); + return s.errors; } -ErrorVec Unifier::canUnify(TypePackId superTy, TypePackId subTy, bool isFunctionCall) +ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall) { Unifier s = makeChildUnifier(); - s.tryUnify_(superTy, subTy, isFunctionCall); - s.log.rollback(); + s.tryUnify_(subTy, superTy, isFunctionCall); + + if (!FFlag::LuauUseCommittingTxnLog) + s.DEPRECATED_log.rollback(); + return s.errors; } -void Unifier::tryUnify(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { sharedState.counters.iterationCount = 0; - tryUnify_(superTp, subTp, isFunctionCall); + tryUnify_(subTp, superTp, isFunctionCall); +} + +static std::pair, std::optional> logAwareFlatten(TypePackId tp, const TxnLog& log) +{ + tp = log.follow(tp); + + std::vector flattened; + std::optional tail = std::nullopt; + + TypePackIterator it(tp, &log); + + for (; it != end(tp); ++it) + { + flattened.push_back(*it); + } + + tail = it.tail(); + + return {flattened, tail}; } /* * This is quite tricky: we are walking two rope-like structures and unifying corresponding elements. * If one is longer than the other, but the short end is free, we grow it to the required length. */ -void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCall) +void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCall) { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -795,252 +1064,458 @@ void Unifier::tryUnify_(TypePackId superTp, TypePackId subTp, bool isFunctionCal return; } - superTp = follow(superTp); - subTp = follow(subTp); - - while (auto r = get(subTp)) - { - if (r->head.empty() && r->tail) - subTp = follow(*r->tail); - else - break; - } - - while (auto l = get(superTp)) - { - if (l->head.empty() && l->tail) - superTp = follow(*l->tail); - else - break; - } - - if (superTp == subTp) - return; - - if (get(superTp)) + if (FFlag::LuauUseCommittingTxnLog) { - occursCheck(superTp, subTp); + superTp = log.follow(superTp); + subTp = log.follow(subTp); - // The occurrence check might have caused superTp no longer to be a free type - if (!get(superTp)) + while (auto tp = log.getMutable(subTp)) { - log(superTp); - *asMutable(superTp) = Unifiable::Bound(subTp); + if (tp->head.empty() && tp->tail) + subTp = log.follow(*tp->tail); + else + break; } - } - else if (get(subTp)) - { - occursCheck(subTp, superTp); - // The occurrence check might have caused superTp no longer to be a free type - if (!get(subTp)) + while (auto tp = log.getMutable(superTp)) { - log(subTp); - *asMutable(subTp) = Unifiable::Bound(superTp); + if (tp->head.empty() && tp->tail) + superTp = log.follow(*tp->tail); + else + break; } - } - - else if (get(superTp)) - tryUnifyWithAny(superTp, subTp); - else if (get(subTp)) - tryUnifyWithAny(subTp, superTp); - - else if (get(superTp)) - tryUnifyVariadics(superTp, subTp, false); - else if (get(subTp)) - tryUnifyVariadics(subTp, superTp, true); - - else if (get(superTp) && get(subTp)) - { - auto l = get(superTp); - auto r = get(subTp); - - // If the size of two heads does not match, but both packs have free tail - // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = flatten(superTp); - auto [subTypes, subTail] = flatten(subTp); - - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); - - auto superIter = WeirdIter{superTp}; - auto subIter = WeirdIter{subTp}; - - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); - }; - - const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); - - int loopCount = 0; + if (superTp == subTp) + return; - do + if (log.getMutable(superTp)) { - if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) - ice("Detected possibly infinite TypePack growth"); - - ++loopCount; - - if (superIter.good() && subIter.growing) - asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); + occursCheck(superTp, subTp); - if (subIter.good() && superIter.growing) - asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); - - if (superIter.good() && subIter.good()) + if (!log.getMutable(superTp)) { - tryUnify_(*superIter, *subIter); - - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) - firstPackErrorPos = loopCount; - - superIter.advance(); - subIter.advance(); - continue; + log.replace(superTp, Unifiable::Bound(subTp)); } + } + else if (log.getMutable(subTp)) + { + occursCheck(subTp, superTp); - // If both are at the end, we're done - if (!superIter.good() && !subIter.good()) + if (!log.getMutable(subTp)) { - const bool lFreeTail = l->tail && get(follow(*l->tail)) != nullptr; - const bool rFreeTail = r->tail && get(follow(*r->tail)) != nullptr; - if (lFreeTail && rFreeTail) - tryUnify_(*l->tail, *r->tail); - else if (lFreeTail) - tryUnify_(*l->tail, emptyTp); - else if (rFreeTail) - tryUnify_(*r->tail, emptyTp); - - break; + log.replace(subTp, Unifiable::Bound(superTp)); } + } + else if (log.getMutable(superTp)) + tryUnifyWithAny(subTp, superTp); + else if (log.getMutable(subTp)) + tryUnifyWithAny(superTp, subTp); + else if (log.getMutable(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (log.getMutable(subTp)) + tryUnifyVariadics(superTp, subTp, true); + else if (log.getMutable(superTp) && log.getMutable(subTp)) + { + auto superTpv = log.getMutable(superTp); + auto subTpv = log.getMutable(subTp); - // If both tails are free, bind one to the other and call it a day - if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*superIter.pack->tail, *subIter.pack->tail); + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = logAwareFlatten(superTp, log); + auto [subTypes, subTail] = logAwareFlatten(subTp, log); - // If just one side is free on its tail, grow it to fit the other side. - // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. - if (superIter.canGrow()) - superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + bool noInfiniteGrowth = + (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); - else if (subIter.canGrow()) - subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + auto superIter = WeirdIter(superTp, log); + auto subIter = WeirdIter(subTp, log); - else + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do { - // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) - { - superIter.advance(); - continue; - } - else if (subIter.good() && isOptional(*subIter)) - { - subIter.advance(); - continue; - } + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) - { - superIter.advance(); - continue; - } + ++loopCount; - if (get(superIter.packId)) + if (superIter.good() && subIter.growing) { - tryUnifyVariadics(superIter.packId, subIter.packId, false, int(subIter.index)); - return; + subIter.pushType(mkFreshType(subIter.level)); } - if (get(subIter.packId)) + if (subIter.good() && superIter.growing) { - tryUnifyVariadics(subIter.packId, superIter.packId, true, int(superIter.index)); - return; + superIter.pushType(mkFreshType(superIter.level)); } - if (!isFunctionCall && subIter.good()) + if (superIter.good() && subIter.good()) { - // Sometimes it is ok to pass too many arguments - return; - } + tryUnify_(*subIter, *superIter); - // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking the values returned by a function, we swap - // these to produce the expected error message. - size_t expectedSize = size(superTp); - size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) - std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; - while (superIter.good()) - { - tryUnify_(getSingletonTypes().errorRecoveryType(), *superIter); superIter.advance(); + subIter.advance(); + continue; } - while (subIter.good()) + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) { - tryUnify_(getSingletonTypes().errorRecoveryType(), *subIter); - subIter.advance(); + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + tryUnify_(*subTpv->tail, *superTpv->tail); + else if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); + + break; } - return; - } + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) + { + superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { + subIter.advance(); + continue; + } + + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + { + superIter.advance(); + continue; + } + + if (log.getMutable(superIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + return; + } + + if (log.getMutable(subIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); + } - } while (!noInfiniteGrowth); + return; + } + + } while (!noInfiniteGrowth); + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + } } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + superTp = follow(superTp); + subTp = follow(subTp); + + while (auto tp = get(subTp)) + { + if (tp->head.empty() && tp->tail) + subTp = follow(*tp->tail); + else + break; + } + + while (auto tp = get(superTp)) + { + if (tp->head.empty() && tp->tail) + superTp = follow(*tp->tail); + else + break; + } + + if (superTp == subTp) + return; + + if (get(superTp)) + { + occursCheck(superTp, subTp); + + if (!get(superTp)) + { + DEPRECATED_log(superTp); + *asMutable(superTp) = Unifiable::Bound(subTp); + } + } + else if (get(subTp)) + { + occursCheck(subTp, superTp); + + if (!get(subTp)) + { + DEPRECATED_log(subTp); + *asMutable(subTp) = Unifiable::Bound(superTp); + } + } + + else if (get(superTp)) + tryUnifyWithAny(subTp, superTp); + + else if (get(subTp)) + tryUnifyWithAny(superTp, subTp); + + else if (get(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (get(subTp)) + tryUnifyVariadics(superTp, subTp, true); + + else if (get(superTp) && get(subTp)) + { + auto superTpv = get(superTp); + auto subTpv = get(subTp); + + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = flatten(superTp); + auto [subTypes, subTail] = flatten(subTp); + + bool noInfiniteGrowth = + (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + + auto superIter = DEPRECATED_WeirdIter{superTp}; + auto subIter = DEPRECATED_WeirdIter{subTp}; + + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do + { + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); + + ++loopCount; + + if (superIter.good() && subIter.growing) + asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); + + if (subIter.good() && superIter.growing) + asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); + + if (superIter.good() && subIter.good()) + { + tryUnify_(*subIter, *superIter); + + if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; + + superIter.advance(); + subIter.advance(); + continue; + } + + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) + { + const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; + if (lFreeTail && rFreeTail) + tryUnify_(*subTpv->tail, *superTpv->tail); + else if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); + + break; + } + + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) + { + superIter.advance(); + continue; + } + else if (subIter.good() && isOptional(*subIter)) + { + subIter.advance(); + continue; + } + + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) + { + superIter.advance(); + continue; + } + + if (get(superIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + return; + } + + if (get(subIter.packId)) + { + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } + + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; + } + + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + + while (superIter.good()) + { + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } + + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); + } + + return; + } + + } while (!noInfiniteGrowth); + } + else + { + errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + } } } -void Unifier::tryUnifyPrimitives(TypeId superTy, TypeId subTy) +void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* lp = get(superTy); - const PrimitiveTypeVar* rp = get(subTy); - if (!lp || !rp) + const PrimitiveTypeVar* superPrim = get(superTy); + const PrimitiveTypeVar* subPrim = get(subTy); + if (!superPrim || !subPrim) ice("passed non primitive types to unifyPrimitives"); - if (lp->type != rp->type) + if (superPrim->type != subPrim->type) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } -void Unifier::tryUnifySingletons(TypeId superTy, TypeId subTy) +void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) { - const PrimitiveTypeVar* lp = get(superTy); - const SingletonTypeVar* ls = get(superTy); - const SingletonTypeVar* rs = get(subTy); + const PrimitiveTypeVar* superPrim = get(superTy); + const SingletonTypeVar* superSingleton = get(superTy); + const SingletonTypeVar* subSingleton = get(subTy); - if ((!lp && !ls) || !rs) + if ((!superPrim && !superSingleton) || !subSingleton) ice("passed non singleton/primitive types to unifySingletons"); - if (ls && *ls == *rs) + if (superSingleton && *superSingleton == *subSingleton) return; - if (lp && lp->type == PrimitiveTypeVar::Boolean && get(rs) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) return; - if (lp && lp->type == PrimitiveTypeVar::String && get(rs) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } -void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCall) +void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) { - FunctionTypeVar* lf = getMutable(superTy); - FunctionTypeVar* rf = getMutable(subTy); - if (!lf || !rf) + FunctionTypeVar* superFunction = getMutable(superTy); + FunctionTypeVar* subFunction = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); + } + + if (!superFunction || !subFunction) ice("passed non-function types to unifyFunction"); - size_t numGenerics = lf->generics.size(); - if (numGenerics != rf->generics.size()) + size_t numGenerics = superFunction->generics.size(); + if (numGenerics != subFunction->generics.size()) { - numGenerics = std::min(lf->generics.size(), rf->generics.size()); + numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); if (FFlag::LuauExtendedFunctionMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); @@ -1048,10 +1523,10 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - size_t numGenericPacks = lf->genericPacks.size(); - if (numGenericPacks != rf->genericPacks.size()) + size_t numGenericPacks = superFunction->genericPacks.size(); + if (numGenericPacks != subFunction->genericPacks.size()) { - numGenericPacks = std::min(lf->genericPacks.size(), rf->genericPacks.size()); + numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); if (FFlag::LuauExtendedFunctionMismatchError) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); @@ -1060,7 +1535,12 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal } for (size_t i = 0; i < numGenerics; i++) - log.pushSeen(lf->generics[i], rf->generics[i]); + { + if (FFlag::LuauUseCommittingTxnLog) + log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + else + DEPRECATED_log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + } CountMismatch::Context context = ctx; @@ -1071,7 +1551,7 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal if (FFlag::LuauExtendedFunctionMismatchError) { innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); bool reported = !innerState.errors.empty(); @@ -1085,13 +1565,13 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + innerState.tryUnify_(subFunction->retType, superFunction->retType); if (!reported) { if (auto e = hasUnificationTooComplex(innerState.errors)) errors.push_back(*e); - else if (!innerState.errors.empty() && size(lf->retType) == 1 && finite(lf->retType)) + else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) errors.push_back( @@ -1104,38 +1584,70 @@ void Unifier::tryUnifyFunctions(TypeId superTy, TypeId subTy, bool isFunctionCal else { ctx = CountMismatch::Arg; - innerState.tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - innerState.tryUnify_(lf->retType, rf->retType); + innerState.tryUnify_(subFunction->retType, superFunction->retType); checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + log.concat(std::move(innerState.log)); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } } else { ctx = CountMismatch::Arg; - tryUnify_(rf->argTypes, lf->argTypes, isFunctionCall); + tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); ctx = CountMismatch::Result; - tryUnify_(lf->retType, rf->retType); + tryUnify_(subFunction->retType, superFunction->retType); } - if (lf->definition && !rf->definition && !subTy->persistent) + if (FFlag::LuauUseCommittingTxnLog) { - rf->definition = lf->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + PendingType* newSubTy = log.queue(subTy); + FunctionTypeVar* newSubFtv = getMutable(newSubTy); + LUAU_ASSERT(newSubFtv); + newSubFtv->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + PendingType* newSuperTy = log.queue(superTy); + FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); + LUAU_ASSERT(newSuperFtv); + newSuperFtv->definition = subFunction->definition; + } } - else if (!lf->definition && rf->definition && !superTy->persistent) + else { - lf->definition = rf->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + subFunction->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + superFunction->definition = subFunction->definition; + } } ctx = context; for (int i = int(numGenerics) - 1; 0 <= i; i--) - log.popSeen(lf->generics[i], rf->generics[i]); + { + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superFunction->generics[i], subFunction->generics[i]); + else + DEPRECATED_log.popSeen(superFunction->generics[i], subFunction->generics[i]); + } } namespace @@ -1160,77 +1672,84 @@ struct Resetter } // namespace -void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { if (!FFlag::LuauTableSubtypingVariance2) - return DEPRECATED_tryUnifyTables(left, right, isIntersection); + return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); std::vector missingProperties; std::vector extraProperties; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer && rt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer && subTable->state != TableState::Free) { - for (const auto& [propName, superProp] : lt->props) + for (const auto& [propName, superProp] : superTable->props) { - auto subIter = rt->props.find(propName); - if (subIter == rt->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + auto subIter = subTable->props.find(propName); + if (subIter == subTable->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) missingProperties.push_back(propName); } if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !lt->indexer && lt->state != TableState::Unsealed && - lt->state != TableState::Free) + if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && + superTable->state != TableState::Free) { - for (const auto& [propName, subProp] : rt->props) + for (const auto& [propName, subProp] : subTable->props) { - auto superIter = lt->props.find(propName); - if (superIter == lt->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + auto superIter = superTable->props.find(propName); + if (superIter == superTable->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) extraProperties.push_back(propName); } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } } - // Reminder: left is the supertype, right is the subtype. // Width subtyping: any property in the supertype must be in the subtype, // and the types must agree. - for (const auto& [name, prop] : lt->props) + for (const auto& [name, prop] : superTable->props) { - const auto& r = rt->props.find(name); - if (r != rt->props.end()) + const auto& r = subTable->props.find(name); + if (r != subTable->props.end()) { // TODO: read-only properties don't need invariance Resetter resetter{&variance}; variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, r->second.type); + innerState.tryUnify_(r->second.type, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (rt->indexer && isString(rt->indexer->indexType)) + else if (subTable->indexer && isString(subTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1238,37 +1757,55 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, rt->indexer->indexResultType); + innerState.tryUnify_(subTable->indexer->indexResultType, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } else if (isOptional(prop.type) || get(follow(prop.type))) // TODO: this case is unsound, but without it our test suite fails. CLI-46031 // TODO: should isOptional(anyType) be true? { } - else if (rt->state == TableState::Free) + else if (subTable->state == TableState::Free) { - log(rt); - rt->props[name] = prop; + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* ttv = getMutable(pendingSub); + LUAU_ASSERT(ttv); + ttv->props[name] = prop; + } + else + { + DEPRECATED_log(subTy); + subTable->props[name] = prop; + } } else missingProperties.push_back(name); } - for (const auto& [name, prop] : rt->props) + for (const auto& [name, prop] : subTable->props) { - if (lt->props.count(name)) + if (superTable->props.count(name)) { // If both lt and rt contain the property, then // we're done since we already unified them above } - else if (lt->indexer && isString(lt->indexer->indexType)) + else if (superTable->indexer && isString(superTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1276,24 +1813,42 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, lt->indexer->indexResultType); + innerState.tryUnify_(superTable->indexer->indexResultType, prop.type); - checkChildUnifierTypeMismatch(innerState.errors, name, left, right); + checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (lt->state == TableState::Unsealed) + else if (superTable->state == TableState::Unsealed) { // TODO: this case is unsound when variance is Invariant, but without it lua-apps fails to typecheck. // TODO: file a JIRA // TODO: hopefully readonly/writeonly properties will fix this. Property clone = prop; clone.type = deeplyOptional(clone.type); - log(left); - lt->props[name] = clone; + + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = clone; + } + else + { + DEPRECATED_log(superTy); + superTable->props[name] = clone; + } } else if (variance == Covariant) { @@ -1303,61 +1858,93 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) // TODO: should isOptional(anyType) be true? { } - else if (lt->state == TableState::Free) + else if (superTable->state == TableState::Free) { - log(left); - lt->props[name] = prop; + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = prop; + } + else + { + DEPRECATED_log(superTy); + superTable->props[name] = prop; + } } else extraProperties.push_back(name); } // Unify indexers - if (lt->indexer && rt->indexer) + if (superTable->indexer && subTable->indexer) { // TODO: read-only indexers don't need invariance Resetter resetter{&variance}; variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnify(*lt->indexer, *rt->indexer); - checkChildUnifierTypeMismatch(innerState.errors, left, right); - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); + } else - innerState.log.rollback(); + { + if (innerState.errors.empty()) + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + else + innerState.DEPRECATED_log.rollback(); + } } - else if (lt->indexer) + else if (superTable->indexer) { - if (rt->state == TableState::Unsealed || rt->state == TableState::Free) + if (subTable->state == TableState::Unsealed || subTable->state == TableState::Free) { // passing/assigning a table without an indexer to something that has one // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - log(right); - rt->indexer = lt->indexer; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(subTy, superTable->indexer); + } + else + { + DEPRECATED_log(subTy); + subTable->indexer = superTable->indexer; + } } } - else if (rt->indexer && variance == Invariant) + else if (subTable->indexer && variance == Invariant) { // Symmetric if we are invariant - if (lt->state == TableState::Unsealed || lt->state == TableState::Free) + if (superTable->state == TableState::Unsealed || superTable->state == TableState::Free) { - log(left); - lt->indexer = rt->indexer; + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(superTy, subTable->indexer); + } + else + { + DEPRECATED_log(superTy); + superTable->indexer = subTable->indexer; + } } } if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingProperties)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraProperties), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } @@ -1369,18 +1956,32 @@ void Unifier::tryUnifyTables(TypeId left, TypeId right, bool isIntersection) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (lt->boundTo || rt->boundTo) - return tryUnify_(left, right); + if (superTable->boundTo || subTable->boundTo) + return tryUnify_(subTy, superTy); - if (lt->state == TableState::Free) + if (superTable->state == TableState::Free) { - log(lt); - lt->boundTo = right; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(superTable); + superTable->boundTo = subTy; + } } - else if (rt->state == TableState::Free) + else if (subTable->state == TableState::Free) { - log(rt); - rt->boundTo = left; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + DEPRECATED_log(subTy); + subTable->boundTo = superTy; + } } } @@ -1406,99 +2007,129 @@ TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map see return types->addType(UnionTypeVar{{getSingletonTypes().nilType, ty}}); } -void Unifier::DEPRECATED_tryUnifyTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); Resetter resetter{&variance}; variance = Invariant; - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); - if (lt->state == TableState::Sealed && rt->state == TableState::Sealed) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Unsealed) || - (lt->state == TableState::Unsealed && rt->state == TableState::Sealed)) - return tryUnifySealedTables(left, right, isIntersection); - else if ((lt->state == TableState::Sealed && rt->state == TableState::Generic) || - (lt->state == TableState::Generic && rt->state == TableState::Sealed)) - errors.push_back(TypeError{location, TypeMismatch{left, right}}); - else if ((lt->state == TableState::Free) != (rt->state == TableState::Free)) // one table is free and the other is not + if (superTable->state == TableState::Sealed && subTable->state == TableState::Sealed) + return tryUnifySealedTables(subTy, superTy, isIntersection); + else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Unsealed) || + (superTable->state == TableState::Unsealed && subTable->state == TableState::Sealed)) + return tryUnifySealedTables(subTy, superTy, isIntersection); + else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || + (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) + errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not { - TypeId freeTypeId = rt->state == TableState::Free ? right : left; - TypeId otherTypeId = rt->state == TableState::Free ? left : right; + TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; + TypeId otherTypeId = subTable->state == TableState::Free ? superTy : subTy; - return tryUnifyFreeTable(freeTypeId, otherTypeId); + return tryUnifyFreeTable(otherTypeId, freeTypeId); } - else if (lt->state == TableState::Free && rt->state == TableState::Free) + else if (superTable->state == TableState::Free && subTable->state == TableState::Free) { - tryUnifyFreeTable(left, right); + tryUnifyFreeTable(subTy, superTy); // avoid creating a cycle when the types are already pointing at each other - if (follow(left) != follow(right)) + if (follow(superTy) != follow(subTy)) { - log(lt); - lt->boundTo = right; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(superTable); + superTable->boundTo = subTy; + } } return; } - else if (lt->state != TableState::Sealed && rt->state != TableState::Sealed) + else if (superTable->state != TableState::Sealed && subTable->state != TableState::Sealed) { // All free tables are checked in one of the branches above - LUAU_ASSERT(lt->state != TableState::Free); - LUAU_ASSERT(rt->state != TableState::Free); + LUAU_ASSERT(superTable->state != TableState::Free); + LUAU_ASSERT(subTable->state != TableState::Free); // Tables must have exactly the same props and their types must all unify // I honestly have no idea if this is remotely close to reasonable. - for (const auto& [name, prop] : lt->props) + for (const auto& [name, prop] : superTable->props) { - const auto& r = rt->props.find(name); - if (r == rt->props.end()) - errors.push_back(TypeError{location, UnknownProperty{right, name}}); + const auto& r = subTable->props.find(name); + if (r == subTable->props.end()) + errors.push_back(TypeError{location, UnknownProperty{subTy, name}}); else - tryUnify_(prop.type, r->second.type); + tryUnify_(r->second.type, prop.type); } - if (lt->indexer && rt->indexer) - tryUnify(*lt->indexer, *rt->indexer); - else if (lt->indexer) + if (superTable->indexer && subTable->indexer) + tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (superTable->indexer) { // passing/assigning a table without an indexer to something that has one // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. - if (rt->state == TableState::Unsealed) - rt->indexer = lt->indexer; + if (subTable->state == TableState::Unsealed) + { + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(subTy, superTable->indexer); + } + else + { + subTable->indexer = superTable->indexer; + } + } else - errors.push_back(TypeError{location, CannotExtendTable{right, CannotExtendTable::Indexer}}); + errors.push_back(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); } } - else if (lt->state == TableState::Sealed) + else if (superTable->state == TableState::Sealed) { // lt is sealed and so it must be possible for rt to have precisely the same shape // Verify that this is the case, then bind rt to lt. ice("unsealed tables are not working yet", location); } - else if (rt->state == TableState::Sealed) - return tryUnifyTables(right, left, isIntersection); + else if (subTable->state == TableState::Sealed) + return tryUnifyTables(superTy, subTy, isIntersection); else ice("tryUnifyTables"); } -void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) +void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { - TableTypeVar* freeTable = getMutable(freeTypeId); - TableTypeVar* otherTable = getMutable(otherTypeId); - if (!freeTable || !otherTable) + TableTypeVar* freeTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + freeTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!freeTable || !subTable) ice("passed non-table types to tryUnifyFreeTable"); // Any properties in freeTable must unify with those in otherTable. // Then bind freeTable to otherTable. for (const auto& [freeName, freeProp] : freeTable->props) { - if (auto otherProp = findTablePropertyRespectingMeta(otherTypeId, freeName)) + if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) { - tryUnify_(*otherProp, freeProp.type); + tryUnify_(freeProp.type, *subProp); /* * TypeVars are commonly cyclic, so it is entirely possible @@ -1508,84 +2139,133 @@ void Unifier::tryUnifyFreeTable(TypeId freeTypeId, TypeId otherTypeId) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (!get(freeTypeId) || !get(otherTypeId)) - return tryUnify_(freeTypeId, otherTypeId); + if (FFlag::LuauUseCommittingTxnLog) + { + if (!log.getMutable(superTy) || !log.getMutable(subTy)) + return tryUnify_(subTy, superTy); - if (freeTable->boundTo) - return tryUnify_(freeTypeId, otherTypeId); + if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) + return tryUnify_(subTy, superTy); + } + else + { + if (!get(superTy) || !get(subTy)) + return tryUnify_(subTy, superTy); + + if (freeTable->boundTo) + return tryUnify_(subTy, superTy); + } } else { // If the other table is also free, then we are learning that it has more // properties than we previously thought. Else, it is an error. - if (otherTable->state == TableState::Free) - otherTable->props.insert({freeName, freeProp}); + if (subTable->state == TableState::Free) + { + if (FFlag::LuauUseCommittingTxnLog) + { + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* pendingSubTtv = getMutable(pendingSub); + LUAU_ASSERT(pendingSubTtv); + pendingSubTtv->props.insert({freeName, freeProp}); + } + else + { + subTable->props.insert({freeName, freeProp}); + } + } else - errors.push_back(TypeError{location, UnknownProperty{otherTypeId, freeName}}); + errors.push_back(TypeError{location, UnknownProperty{subTy, freeName}}); } } - if (freeTable->indexer && otherTable->indexer) + if (freeTable->indexer && subTable->indexer) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify(*freeTable->indexer, *otherTable->indexer); + innerState.tryUnifyIndexer(*subTable->indexer, *freeTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, freeTypeId, otherTypeId); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + else if (subTable->state == TableState::Free && freeTable->indexer) + { + if (FFlag::LuauUseCommittingTxnLog) + { + log.changeIndexer(superTy, subTable->indexer); + } + else + { + freeTable->indexer = subTable->indexer; + } } - else if (otherTable->state == TableState::Free && freeTable->indexer) - freeTable->indexer = otherTable->indexer; - if (!freeTable->boundTo && otherTable->state != TableState::Free) + if (!freeTable->boundTo && subTable->state != TableState::Free) { - log(freeTable); - freeTable->boundTo = otherTypeId; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(superTy, subTy); + } + else + { + DEPRECATED_log(freeTable); + freeTable->boundTo = subTy; + } } } -void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersection) +void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) { - TableTypeVar* lt = getMutable(left); - TableTypeVar* rt = getMutable(right); - if (!lt || !rt) + TableTypeVar* superTable = getMutable(superTy); + TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + + if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); Unifier innerState = makeChildUnifier(); std::vector missingPropertiesInSuper; - bool isUnnamedTable = rt->name == std::nullopt && rt->syntheticName == std::nullopt; + bool isUnnamedTable = subTable->name == std::nullopt && subTable->syntheticName == std::nullopt; bool errorReported = false; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !rt->indexer) + if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer) { - for (const auto& [propName, superProp] : lt->props) + for (const auto& [propName, superProp] : superTable->props) { - auto subIter = rt->props.find(propName); - if (subIter == rt->props.end() && !isOptional(superProp.type)) + auto subIter = subTable->props.find(propName); + if (subIter == subTable->props.end() && !isOptional(superProp.type)) missingPropertiesInSuper.push_back(propName); } if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } } // Tables must have exactly the same props and their types must all unify - for (const auto& it : lt->props) + for (const auto& it : superTable->props) { - const auto& r = rt->props.find(it.first); - if (r == rt->props.end()) + const auto& r = subTable->props.find(it.first); + if (r == subTable->props.end()) { if (isOptional(it.second.type)) continue; missingPropertiesInSuper.push_back(it.first); - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -1594,7 +2274,7 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio size_t oldErrorSize = innerState.errors.size(); Location old = innerState.location; innerState.location = *r->second.location; - innerState.tryUnify_(it.second.type, r->second.type); + innerState.tryUnify_(r->second.type, it.second.type); innerState.location = old; if (oldErrorSize != innerState.errors.size() && !errorReported) @@ -1605,113 +2285,165 @@ void Unifier::tryUnifySealedTables(TypeId left, TypeId right, bool isIntersectio } else { - innerState.tryUnify_(it.second.type, r->second.type); + innerState.tryUnify_(r->second.type, it.second.type); } } } - if (lt->indexer || rt->indexer) + if (superTable->indexer || subTable->indexer) { - if (lt->indexer && rt->indexer) - innerState.tryUnify(*lt->indexer, *rt->indexer); - else if (rt->state == TableState::Unsealed) + if (FFlag::LuauUseCommittingTxnLog) { - if (lt->indexer && !rt->indexer) - rt->indexer = lt->indexer; - } - else if (lt->state == TableState::Unsealed) - { - if (rt->indexer && !lt->indexer) - lt->indexer = rt->indexer; + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) + { + if (superTable->indexer && !subTable->indexer) + { + log.changeIndexer(subTy, superTable->indexer); + } + } + else if (superTable->state == TableState::Unsealed) + { + if (subTable->indexer && !superTable->indexer) + { + log.changeIndexer(superTy, subTable->indexer); + } + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + for (const auto& [name, type] : subTable->props) + { + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); + } + } + else + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - else if (lt->indexer) + else { - innerState.tryUnify_(lt->indexer->indexType, getSingletonTypes().stringType); - // We already try to unify properties in both tables. - // Skip those and just look for the ones remaining and see if they fit into the indexer. - for (const auto& [name, type] : rt->props) + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) { - const auto& it = lt->props.find(name); - if (it == lt->props.end()) - innerState.tryUnify_(lt->indexer->indexResultType, type.type); + if (superTable->indexer && !subTable->indexer) + subTable->indexer = superTable->indexer; } + else if (superTable->state == TableState::Unsealed) + { + if (subTable->indexer && !superTable->indexer) + superTable->indexer = subTable->indexer; + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + // We already try to unify properties in both tables. + // Skip those and just look for the ones remaining and see if they fit into the indexer. + for (const auto& [name, type] : subTable->props) + { + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); + } + } + else + innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); } - else - innerState.errors.push_back(TypeError{location, TypeMismatch{left, right}}); } - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + { + if (!errorReported) + log.concat(std::move(innerState.log)); + } + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); if (errorReported) return; if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(missingPropertiesInSuper)}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } - // If the superTy/left is an immediate part of an intersection type, do not do extra-property check. + // If the superTy is an immediate part of an intersection type, do not do extra-property check. // Otherwise, we would falsely generate an extra-property-error for 's' in this code: // local a: {n: number} & {s: string} = {n=1, s=""} // When checking against the table '{n: number}'. - if (!isIntersection && lt->state != TableState::Unsealed && !lt->indexer) + if (!isIntersection && superTable->state != TableState::Unsealed && !superTable->indexer) { // Check for extra properties in the subTy std::vector extraPropertiesInSub; - for (const auto& it : rt->props) + for (const auto& [subKey, subProp] : subTable->props) { - const auto& r = lt->props.find(it.first); - if (r == lt->props.end()) + const auto& superIt = superTable->props.find(subKey); + if (superIt == superTable->props.end()) { - if (isOptional(it.second.type)) + if (isOptional(subProp.type)) continue; - extraPropertiesInSub.push_back(it.first); + extraPropertiesInSub.push_back(subKey); } } if (!extraPropertiesInSub.empty()) { - errors.push_back(TypeError{location, MissingProperties{left, right, std::move(extraPropertiesInSub), MissingProperties::Extra}}); + errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); return; } } - checkChildUnifierTypeMismatch(innerState.errors, left, right); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); } -void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reversed) +void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) { - const MetatableTypeVar* lhs = get(metatable); - if (!lhs) + const MetatableTypeVar* superMetatable = get(superTy); + if (!superMetatable) ice("tryUnifyMetatable invoked with non-metatable TypeVar"); - TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other}}; + TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy}}; - if (const MetatableTypeVar* rhs = get(other)) + if (const MetatableTypeVar* subMetatable = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(lhs->table, rhs->table); - innerState.tryUnify_(lhs->metatable, rhs->metatable); + innerState.tryUnify_(subMetatable->table, superMetatable->table); + innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); if (auto e = hasUnificationTooComplex(innerState.errors)) errors.push_back(*e); else if (!innerState.errors.empty()) errors.push_back( - TypeError{location, TypeMismatch{reversed ? other : metatable, reversed ? metatable : other, "", innerState.errors.front()}}); + TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); - log.concat(std::move(innerState.log)); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } - else if (TableTypeVar* rhs = getMutable(other)) + else if (TableTypeVar* subTable = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : getMutable(subTy)) { - switch (rhs->state) + switch (subTable->state) { case TableState::Free: { - tryUnify_(lhs->table, other); - rhs->boundTo = metatable; + tryUnify_(subTy, superMetatable->table); + + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + subTable->boundTo = superTy; + } break; } @@ -1722,7 +2454,8 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse errors.push_back(mismatchError); } } - else if (get(other) || get(other)) + else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) + : (get(subTy) || get(subTy))) { } else @@ -1732,7 +2465,7 @@ void Unifier::tryUnifyWithMetatable(TypeId metatable, TypeId other, bool reverse } // Class unification is almost, but not quite symmetrical. We use the 'reversed' boolean to indicate which scenario we are evaluating. -void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) +void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { if (reversed) std::swap(superTy, subTy); @@ -1763,7 +2496,7 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) } ice("Illegal variance setting!"); } - else if (TableTypeVar* table = getMutable(subTy)) + else if (TableTypeVar* subTable = getMutable(subTy)) { /** * A free table is something whose shape we do not exactly know yet. @@ -1775,12 +2508,12 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) * * Tables that are not free are known to be actual tables. */ - if (table->state != TableState::Free) + if (subTable->state != TableState::Free) return fail(); bool ok = true; - for (const auto& [propName, prop] : table->props) + for (const auto& [propName, prop] : subTable->props) { const Property* classProp = lookupClassProp(superClass, propName); if (!classProp) @@ -1791,23 +2524,37 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) else { Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(prop.type, classProp->type); + innerState.tryUnify_(classProp->type, prop.type); checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (innerState.errors.empty()) + if (FFlag::LuauUseCommittingTxnLog) { - log.concat(std::move(innerState.log)); + if (innerState.errors.empty()) + { + log.concat(std::move(innerState.log)); + } + else + { + ok = false; + } } else { - ok = false; - innerState.log.rollback(); + if (innerState.errors.empty()) + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } + else + { + ok = false; + innerState.DEPRECATED_log.rollback(); + } } } } - if (table->indexer) + if (subTable->indexer) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; @@ -1817,17 +2564,24 @@ void Unifier::tryUnifyWithClass(TypeId superTy, TypeId subTy, bool reversed) if (!ok) return; - log(table); - table->boundTo = superTy; + if (FFlag::LuauUseCommittingTxnLog) + { + log.bindTable(subTy, superTy); + } + else + { + DEPRECATED_log(subTable); + subTable->boundTo = superTy; + } } else return fail(); } -void Unifier::tryUnify(const TableIndexer& superIndexer, const TableIndexer& subIndexer) +void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) { - tryUnify_(superIndexer.indexType, subIndexer.indexType); - tryUnify_(superIndexer.indexResultType, subIndexer.indexResultType); + tryUnify_(subIndexer.indexType, superIndexer.indexType); + tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); } static void queueTypePack(std::vector& queue, DenseHashSet& seenTypePacks, Unifier& state, TypePackId a, TypePackId anyTypePack) @@ -1840,54 +2594,85 @@ static void queueTypePack(std::vector& queue, DenseHashSet& break; seenTypePacks.insert(a); - if (get(a)) + if (FFlag::LuauUseCommittingTxnLog) { - state.log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; + if (state.log.getMutable(a)) + { + state.log.replace(a, Unifiable::Bound{anyTypePack}); + } + else if (auto tp = state.log.getMutable(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } } - else if (auto tp = get(a)) + else { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; + if (get(a)) + { + state.DEPRECATED_log(a); + *asMutable(a) = Unifiable::Bound{anyTypePack}; + } + else if (auto tp = get(a)) + { + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; + } } } } -void Unifier::tryUnifyVariadics(TypePackId superTp, TypePackId subTp, bool reversed, int subOffset) +void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { - const VariadicTypePack* lv = get(superTp); - if (!lv) + const VariadicTypePack* superVariadic = get(superTp); + + if (FFlag::LuauUseCommittingTxnLog) + { + superVariadic = log.getMutable(superTp); + } + + if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); - if (const VariadicTypePack* rv = get(subTp)) - tryUnify_(reversed ? rv->ty : lv->ty, reversed ? lv->ty : rv->ty); + if (const VariadicTypePack* subVariadic = get(subTp)) + tryUnify_(reversed ? superVariadic->ty : subVariadic->ty, reversed ? subVariadic->ty : superVariadic->ty); else if (get(subTp)) { - TypePackIterator rIter = begin(subTp); - TypePackIterator rEnd = end(subTp); + TypePackIterator subIter = begin(subTp, &log); + TypePackIterator subEnd = end(subTp); - std::advance(rIter, subOffset); + std::advance(subIter, subOffset); - while (rIter != rEnd) + while (subIter != subEnd) { - tryUnify_(reversed ? *rIter : lv->ty, reversed ? lv->ty : *rIter); - ++rIter; + tryUnify_(reversed ? superVariadic->ty : *subIter, reversed ? *subIter : superVariadic->ty); + ++subIter; } - if (std::optional maybeTail = rIter.tail()) + if (std::optional maybeTail = subIter.tail()) { TypePackId tail = follow(*maybeTail); if (get(tail)) { - log(tail); - *asMutable(tail) = BoundTypePack{superTp}; + if (FFlag::LuauUseCommittingTxnLog) + { + log.replace(tail, BoundTypePack(superTp)); + } + else + { + DEPRECATED_log(tail); + *asMutable(tail) = BoundTypePack{superTp}; + } } else if (const VariadicTypePack* vtp = get(tail)) { - tryUnify_(lv->ty, vtp->ty); + tryUnify_(vtp->ty, superVariadic->ty); } else if (get(tail)) { @@ -1914,65 +2699,113 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas { while (!queue.empty()) { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.find(ty)) - continue; - seen.insert(ty); - - if (get(ty)) - { - state.log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) + if (FFlag::LuauUseCommittingTxnLog) { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); + TypeId ty = state.log.follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); - if (table->indexer) + if (state.log.getMutable(ty)) { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); + state.log.replace(ty, BoundTypeVar{anyType}); } + else if (auto fun = state.log.getMutable(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = state.log.getMutable(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = state.log.getMutable(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (state.log.getMutable(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = state.log.getMutable(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = state.log.getMutable(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); else { - } // Primitives, any, errors, and generics are left untouched. + TypeId ty = follow(queue.back()); + queue.pop_back(); + if (seen.find(ty)) + continue; + seen.insert(ty); + + if (get(ty)) + { + state.DEPRECATED_log(ty); + *asMutable(ty) = BoundTypeVar{anyType}; + } + else if (auto fun = get(ty)) + { + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = get(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); + + if (table->indexer) + { + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); + } + } + else if (auto mt = get(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (get(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = get(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = get(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. + } } } -void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) +void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) { - LUAU_ASSERT(get(any) || get(any)); + LUAU_ASSERT(get(anyTy) || get(anyTy)); // These types are not visited in general loop below - if (get(ty) || get(ty) || get(ty)) + if (get(subTy) || get(subTy) || get(subTy)) return; const TypePackId anyTypePack = types->addTypePack(TypePackVar{VariadicTypePack{getSingletonTypes().anyType}}); - const TypePackId anyTP = get(any) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); + const TypePackId anyTP = get(anyTy) ? anyTypePack : types->addTypePack(TypePackVar{Unifiable::Error{}}); - std::vector queue = {ty}; + std::vector queue = {subTy}; sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); @@ -1980,9 +2813,9 @@ void Unifier::tryUnifyWithAny(TypeId any, TypeId ty) Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); } -void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) +void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) { - LUAU_ASSERT(get(any)); + LUAU_ASSERT(get(anyTp)); const TypeId anyTy = getSingletonTypes().errorRecoveryType(); @@ -1991,9 +2824,9 @@ void Unifier::tryUnifyWithAny(TypePackId any, TypePackId ty) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - queueTypePack(queue, sharedState.tempSeenTp, *this, ty, any); + queueTypePack(queue, sharedState.tempSeenTp, *this, subTy, anyTp); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, any); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, anyTp); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) @@ -2012,54 +2845,105 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays { RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - needle = follow(needle); - haystack = follow(haystack); + auto check = [&](TypeId tv) { + occursCheck(seen, needle, tv); + }; - if (seen.find(haystack)) - return; + if (FFlag::LuauUseCommittingTxnLog) + { + needle = log.follow(needle); + haystack = log.follow(haystack); - seen.insert(haystack); + if (seen.find(haystack)) + return; - if (get(needle)) - return; + seen.insert(haystack); - if (!get(needle)) - ice("Expected needle to be free"); + if (log.getMutable(needle)) + return; - if (needle == haystack) - { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); - return; - } + if (!log.getMutable(needle)) + ice("Expected needle to be free"); - auto check = [&](TypeId tv) { - occursCheck(seen, needle, tv); - }; + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryType()); - if (get(haystack)) - return; - else if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + return; + } + + if (log.getMutable(haystack)) + return; + else if (auto a = log.getMutable(haystack)) { - for (TypeId ty : a->argTypes) - check(ty); + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (TypePackIterator it(a->argTypes, &log); it != end(a->argTypes); ++it) + check(*it); - for (TypeId ty : a->retType) + for (TypePackIterator it(a->retType, &log); it != end(a->retType); ++it) + check(*it); + } + } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->parts) check(ty); } } - else if (auto a = get(haystack)) - { - for (TypeId ty : a->options) - check(ty); - } - else if (auto a = get(haystack)) + else { - for (TypeId ty : a->parts) - check(ty); + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (get(needle)) + return; + + if (!get(needle)) + ice("Expected needle to be free"); + + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + DEPRECATED_log(needle); + *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); + return; + } + + if (get(haystack)) + return; + else if (auto a = get(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (TypeId ty : a->argTypes) + check(ty); + + for (TypeId ty : a->retType) + check(ty); + } + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->options) + check(ty); + } + else if (auto a = get(haystack)) + { + for (TypeId ty : a->parts) + check(ty); + } } } @@ -2072,59 +2956,115 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { - needle = follow(needle); - haystack = follow(haystack); + if (FFlag::LuauUseCommittingTxnLog) + { + needle = log.follow(needle); + haystack = log.follow(haystack); - if (seen.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); - if (get(needle)) - return; + if (log.getMutable(needle)) + return; - if (!get(needle)) - ice("Expected needle pack to be free"); + if (!get(needle)) + ice("Expected needle pack to be free"); - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - while (!get(haystack)) - { - if (needle == haystack) + while (!log.getMutable(haystack)) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); - log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); - return; - } + if (needle == haystack) + { + errors.push_back(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); - if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + return; + } + + if (auto a = get(haystack)) { for (const auto& ty : a->head) { - if (auto f = get(follow(ty))) + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); + if (auto f = log.getMutable(log.follow(ty))) + { + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); + } } } + + if (a->tail) + { + haystack = follow(*a->tail); + continue; + } } + break; + } + } + else + { + needle = follow(needle); + haystack = follow(haystack); + + if (seen.find(haystack)) + return; + + seen.insert(haystack); + + if (get(needle)) + return; - if (a->tail) + if (!get(needle)) + ice("Expected needle pack to be free"); + + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + + while (!get(haystack)) + { + if (needle == haystack) { - haystack = follow(*a->tail); - continue; + errors.push_back(TypeError{location, OccursCheckFailed{}}); + DEPRECATED_log(needle); + *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); } + + if (auto a = get(haystack)) + { + if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) + { + for (const auto& ty : a->head) + { + if (auto f = get(follow(ty))) + { + occursCheck(seen, needle, f->argTypes); + occursCheck(seen, needle, f->retType); + } + } + } + + if (a->tail) + { + haystack = follow(*a->tail); + continue; + } + } + break; } - break; } } Unifier Unifier::makeChildUnifier() { - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState}; + if (FFlag::LuauUseCommittingTxnLog) + return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, &log}; + else + return Unifier{types, mode, globalScope, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/include/Luau/Common.h b/Ast/include/Luau/Common.h index 63cd3df49..fbb03a9e9 100644 --- a/Ast/include/Luau/Common.h +++ b/Ast/include/Luau/Common.h @@ -29,7 +29,7 @@ namespace Luau { -using AssertHandler = int (*)(const char* expression, const char* file, int line); +using AssertHandler = int (*)(const char* expression, const char* file, int line, const char* function); inline AssertHandler& assertHandler() { @@ -37,10 +37,10 @@ inline AssertHandler& assertHandler() return handler; } -inline int assertCallHandler(const char* expression, const char* file, int line) +inline int assertCallHandler(const char* expression, const char* file, int line, const char* function) { if (AssertHandler handler = assertHandler()) - return handler(expression, file, line); + return handler(expression, file, line, function); return 1; } @@ -48,7 +48,7 @@ inline int assertCallHandler(const char* expression, const char* file, int line) } // namespace Luau #if !defined(NDEBUG) || defined(LUAU_ENABLE_ASSERT) -#define LUAU_ASSERT(expr) ((void)(!!(expr) || (Luau::assertCallHandler(#expr, __FILE__, __LINE__) && (LUAU_DEBUGBREAK(), 0)))) +#define LUAU_ASSERT(expr) ((void)(!!(expr) || (Luau::assertCallHandler(#expr, __FILE__, __LINE__, __FUNCTION__) && (LUAU_DEBUGBREAK(), 0)))) #define LUAU_ASSERTENABLED #else #define LUAU_ASSERT(expr) (void)sizeof(!!(expr)) diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index aecb619a1..54a9a26fa 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -107,7 +107,7 @@ static void displayHelp(const char* argv0) printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 35c02f2c8..26d4333a9 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -235,11 +235,14 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, while (lua_next(L, -2) != 0) { - // table, key, value - std::string_view key = lua_tostring(L, -2); + if (lua_type(L, -2) == LUA_TSTRING) + { + // table, key, value + std::string_view key = lua_tostring(L, -2); - if (!key.empty() && Luau::startsWith(key, prefix)) - completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + if (!key.empty() && Luau::startsWith(key, prefix)) + completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + } lua_pop(L, 1); } @@ -253,7 +256,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, lua_rawget(L, -2); lua_remove(L, -2); - if (lua_isnil(L, -1)) + if (!lua_istable(L, -1)) break; lookup.remove_prefix(dot + 1); @@ -266,7 +269,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) { size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.')) + while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == '_')) start--; // look the value up in current global table first @@ -278,6 +281,34 @@ static void completeRepl(lua_State* L, const char* editBuffer, std::vector globalState(luaL_newstate(), lua_close); @@ -292,6 +323,7 @@ static void runRepl() }); std::string buffer; + LinenoiseScopedHistory scopedHistory; for (;;) { @@ -457,7 +489,7 @@ static void displayHelp(const char* argv0) printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); return 1; diff --git a/CLI/Web.cpp b/CLI/Web.cpp index cf5c831e9..416a79f2b 100644 --- a/CLI/Web.cpp +++ b/CLI/Web.cpp @@ -53,30 +53,38 @@ static std::string runCode(lua_State* L, const std::string& source) lua_insert(T, 1); lua_pcall(T, n, 0, 0); } + + lua_pop(L, 1); // pop T + return std::string(); } else { std::string error; + lua_Debug ar; + if (lua_getinfo(L, 0, "sln", &ar)) + { + error += ar.short_src; + error += ':'; + error += std::to_string(ar.currentline); + error += ": "; + } + if (status == LUA_YIELD) { - error = "thread yielded unexpectedly"; + error += "thread yielded unexpectedly"; } else if (const char* str = lua_tostring(T, -1)) { - error = str; + error += str; } error += "\nstack backtrace:\n"; error += lua_debugtrace(T); - error = "Error:" + error; - - fprintf(stdout, "%s", error.c_str()); + lua_pop(L, 1); // pop T + return error; } - - lua_pop(L, 1); - return std::string(); } extern "C" const char* executeScript(const char* source) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index 71631d101..d9694d7d0 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -377,6 +377,7 @@ enum LuauBytecodeTag { // Bytecode version LBC_VERSION = 1, + LBC_VERSION_FUTURE = 2, // TODO: This will be removed in favor of LBC_VERSION with LuauBytecodeV2Force // Types of constant table entries LBC_CONSTANT_NIL = 0, LBC_CONSTANT_BOOLEAN, diff --git a/Compiler/include/Luau/BytecodeBuilder.h b/Compiler/include/Luau/BytecodeBuilder.h index d4ebad6bf..287bf4eed 100644 --- a/Compiler/include/Luau/BytecodeBuilder.h +++ b/Compiler/include/Luau/BytecodeBuilder.h @@ -74,6 +74,7 @@ class BytecodeBuilder void expandJumps(); void setDebugFunctionName(StringRef name); + void setDebugFunctionLineDefined(int line); void setDebugLine(int line); void pushDebugLocal(StringRef name, uint8_t reg, uint32_t startpc, uint32_t endpc); void pushDebugUpval(StringRef name); @@ -162,6 +163,7 @@ class BytecodeBuilder bool isvararg = false; unsigned int debugname = 0; + int debuglinedefined = 0; std::string dump; std::string dumpname; diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 3280c8a40..2d31c409c 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,6 +6,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Write, false) + namespace Luau { @@ -81,6 +83,52 @@ static int getOpLength(LuauOpcode op) } } +inline bool isJumpD(LuauOpcode op) +{ + switch (op) + { + case LOP_JUMP: + case LOP_JUMPIF: + case LOP_JUMPIFNOT: + case LOP_JUMPIFEQ: + case LOP_JUMPIFLE: + case LOP_JUMPIFLT: + case LOP_JUMPIFNOTEQ: + case LOP_JUMPIFNOTLE: + case LOP_JUMPIFNOTLT: + case LOP_FORNPREP: + case LOP_FORNLOOP: + case LOP_FORGLOOP: + case LOP_FORGPREP_INEXT: + case LOP_FORGLOOP_INEXT: + case LOP_FORGPREP_NEXT: + case LOP_FORGLOOP_NEXT: + case LOP_JUMPBACK: + case LOP_JUMPIFEQK: + case LOP_JUMPIFNOTEQK: + return true; + + default: + return false; + } +} + +inline bool isSkipC(LuauOpcode op) +{ + switch (op) + { + case LOP_LOADB: + case LOP_FASTCALL: + case LOP_FASTCALL1: + case LOP_FASTCALL2: + case LOP_FASTCALL2K: + return true; + + default: + return false; + } +} + bool BytecodeBuilder::StringRef::operator==(const StringRef& other) const { return (data && other.data) ? (length == other.length && memcmp(data, other.data, length) == 0) : (data == other.data); @@ -365,13 +413,7 @@ bool BytecodeBuilder::patchJumpD(size_t jumpLabel, size_t targetLabel) unsigned int jumpInsn = insns[jumpLabel]; (void)jumpInsn; - LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_JUMP || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIF || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFLT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQ || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLE || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTLT || - LUAU_INSN_OP(jumpInsn) == LOP_FORNPREP || LUAU_INSN_OP(jumpInsn) == LOP_FORNLOOP || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP || - LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_INEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_INEXT || - LUAU_INSN_OP(jumpInsn) == LOP_FORGPREP_NEXT || LUAU_INSN_OP(jumpInsn) == LOP_FORGLOOP_NEXT || - LUAU_INSN_OP(jumpInsn) == LOP_JUMPBACK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFEQK || LUAU_INSN_OP(jumpInsn) == LOP_JUMPIFNOTEQK); + LUAU_ASSERT(isJumpD(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(LUAU_INSN_D(jumpInsn) == 0); LUAU_ASSERT(targetLabel <= insns.size()); @@ -403,8 +445,7 @@ bool BytecodeBuilder::patchSkipC(size_t jumpLabel, size_t targetLabel) unsigned int jumpInsn = insns[jumpLabel]; (void)jumpInsn; - LUAU_ASSERT(LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL1 || LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2 || - LUAU_INSN_OP(jumpInsn) == LOP_FASTCALL2K); + LUAU_ASSERT(isSkipC(LuauOpcode(LUAU_INSN_OP(jumpInsn)))); LUAU_ASSERT(LUAU_INSN_C(jumpInsn) == 0); int offset = int(targetLabel) - int(jumpLabel) - 1; @@ -428,6 +469,11 @@ void BytecodeBuilder::setDebugFunctionName(StringRef name) functions[currentFunction].dumpname = std::string(name.data, name.length); } +void BytecodeBuilder::setDebugFunctionLineDefined(int line) +{ + functions[currentFunction].debuglinedefined = line; +} + void BytecodeBuilder::setDebugLine(int line) { debugLine = line; @@ -464,7 +510,7 @@ uint32_t BytecodeBuilder::getDebugPC() const void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); - bytecode = char(LBC_VERSION); + bytecode = char(FFlag::LuauBytecodeV2Write ? LBC_VERSION_FUTURE : LBC_VERSION); writeStringTable(bytecode); @@ -565,6 +611,9 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const writeVarInt(ss, child); // debug info + if (FFlag::LuauBytecodeV2Write) + writeVarInt(ss, func.debuglinedefined); + writeVarInt(ss, func.debugname); bool hasLines = true; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 8f74ffedd..6ae490273 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -11,7 +11,6 @@ #include LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) -LUAU_FASTFLAGVARIABLE(LuauBit32CountBuiltin, false) namespace Luau { @@ -179,6 +178,8 @@ struct Compiler if (options.optimizationLevel >= 1 && options.debugLevel >= 2) gatherConstUpvals(func); + bytecode.setDebugFunctionLineDefined(func->location.begin.line + 1); + if (options.debugLevel >= 1 && func->debugname.value) bytecode.setDebugFunctionName(sref(func->debugname)); @@ -3626,9 +3627,9 @@ struct Compiler return LBF_BIT32_RROTATE; if (builtin.method == "rshift") return LBF_BIT32_RSHIFT; - if (builtin.method == "countlz" && FFlag::LuauBit32CountBuiltin) + if (builtin.method == "countlz") return LBF_BIT32_COUNTLZ; - if (builtin.method == "countrz" && FFlag::LuauBit32CountBuiltin) + if (builtin.method == "countrz") return LBF_BIT32_COUNTRZ; } diff --git a/Sources.cmake b/Sources.cmake index a7153eb37..5dd486aaa 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -125,6 +125,7 @@ target_sources(Luau.VM PRIVATE VM/src/linit.cpp VM/src/lmathlib.cpp VM/src/lmem.cpp + VM/src/lnumprint.cpp VM/src/lobject.cpp VM/src/loslib.cpp VM/src/lperf.cpp diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index a01a14819..7e0832e7c 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -138,10 +138,6 @@ /* }================================================================== */ -/* Default number printing format and the string length limit */ -#define LUA_NUMBER_FMT "%.14g" -#define LUAI_MAXNUMBER2STR 32 /* 16 digits, sign, point, and \0 */ - /* @@ LUAI_USER_ALIGNMENT_T is a type that requires maximum alignment. ** CHANGE it if your system requires alignments larger than double. (For diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index a65b03253..c98b95908 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauActivateBeforeExec) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -939,21 +937,7 @@ void lua_call(lua_State* L, int nargs, int nresults) checkresults(L, nargs, nresults); func = L->top - (nargs + 1); - if (FFlag::LuauActivateBeforeExec) - { - luaD_call(L, func, nresults); - } - else - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); - - luaD_call(L, func, nresults); - - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + luaD_call(L, func, nresults); adjustresults(L, nresults); return; @@ -994,21 +978,7 @@ int lua_pcall(lua_State* L, int nargs, int nresults, int errfunc) c.func = L->top - (nargs + 1); /* function to be called */ c.nresults = nresults; - if (FFlag::LuauActivateBeforeExec) - { - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - } - else - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); - - status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); - - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + status = luaD_pcall(L, f_call, &c, savestack(L, c.func), func); adjustresults(L, nresults); return status; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 7ed2a62ee..71975a520 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -7,9 +7,12 @@ #include "lstring.h" #include "lapi.h" #include "lgc.h" +#include "lnumutils.h" #include +LUAU_FASTFLAG(LuauSchubfach) + /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -477,7 +480,17 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) switch (lua_type(L, idx)) { case LUA_TNUMBER: - lua_pushstring(L, lua_tostring(L, idx)); + if (FFlag::LuauSchubfach) + { + double n = lua_tonumber(L, idx); + char s[LUAI_MAXNUM2STR]; + char* e = luai_num2str(s, n); + lua_pushlstring(L, s, e - s); + } + else + { + lua_pushstring(L, lua_tostring(L, idx)); + } break; case LUA_TSTRING: lua_pushvalue(L, idx); @@ -491,11 +504,30 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) case LUA_TVECTOR: { const float* v = lua_tovector(L, idx); + + if (FFlag::LuauSchubfach) + { + char s[LUAI_MAXNUM2STR * LUA_VECTOR_SIZE]; + char* e = s; + for (int i = 0; i < LUA_VECTOR_SIZE; ++i) + { + if (i != 0) + { + *e++ = ','; + *e++ = ' '; + } + e = luai_num2str(e, v[i]); + } + lua_pushlstring(L, s, e - s); + } + else + { #if LUA_VECTOR_SIZE == 4 - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); #else - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); + lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); #endif + } break; } default: diff --git a/VM/src/lbitlib.cpp b/VM/src/lbitlib.cpp index 8b511edf0..093400f2e 100644 --- a/VM/src/lbitlib.cpp +++ b/VM/src/lbitlib.cpp @@ -5,8 +5,6 @@ #include "lcommon.h" #include "lnumutils.h" -LUAU_FASTFLAGVARIABLE(LuauBit32Count, false) - #define ALLONES ~0u #define NBITS int(8 * sizeof(unsigned)) @@ -182,9 +180,6 @@ static int b_replace(lua_State* L) static int b_countlz(lua_State* L) { - if (!FFlag::LuauBit32Count) - luaL_error(L, "bit32.countlz isn't enabled"); - b_uint v = luaL_checkunsigned(L, 1); b_uint r = NBITS; @@ -201,9 +196,6 @@ static int b_countlz(lua_State* L) static int b_countrz(lua_State* L) { - if (!FFlag::LuauBit32Count) - luaL_error(L, "bit32.countrz isn't enabled"); - b_uint v = luaL_checkunsigned(L, 1); b_uint r = NBITS; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 9fe1885fb..2b5382bba 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,6 +12,9 @@ #include #include +LUAU_FASTFLAG(LuauBytecodeV2Read) +LUAU_FASTFLAG(LuauBytecodeV2Force) + static const char* getfuncname(Closure* f); static int currentpc(lua_State* L, CallInfo* ci) @@ -89,6 +92,16 @@ const char* lua_setlocal(lua_State* L, int level, int n) return name; } +static int getlinedefined(Proto* p) +{ + if (FFlag::LuauBytecodeV2Force) + return p->linedefined; + else if (FFlag::LuauBytecodeV2Read && p->linedefined >= 0) + return p->linedefined; + else + return luaG_getline(p, 0); +} + static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, CallInfo* ci) { int status = 1; @@ -108,7 +121,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, { ar->source = getstr(f->l.p->source); ar->what = "Lua"; - ar->linedefined = luaG_getline(f->l.p, 0); + ar->linedefined = getlinedefined(f->l.p); } luaO_chunkid(ar->short_src, ar->source, LUA_IDSIZE); break; @@ -121,7 +134,7 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, } else { - ar->currentline = f->isC ? -1 : luaG_getline(f->l.p, 0); + ar->currentline = f->isC ? -1 : getlinedefined(f->l.p); } break; diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 62bbdb7c9..eb47971a8 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -19,7 +19,6 @@ LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) -LUAU_FASTFLAGVARIABLE(LuauActivateBeforeExec, false) /* ** {====================================================== @@ -228,21 +227,14 @@ void luaD_call(lua_State* L, StkId func, int nResults) { /* is a Lua function? */ L->ci->flags |= LUA_CALLINFO_RETURN; /* luau_execute will stop after returning from the stack frame */ - if (FFlag::LuauActivateBeforeExec) - { - int oldactive = luaC_threadactive(L); - l_setbit(L->stackstate, THREAD_ACTIVEBIT); - luaC_checkthreadsleep(L); + int oldactive = luaC_threadactive(L); + l_setbit(L->stackstate, THREAD_ACTIVEBIT); + luaC_checkthreadsleep(L); - luau_execute(L); /* call it */ + luau_execute(L); /* call it */ - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } - else - { - luau_execute(L); /* call it */ - } + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); } L->nCcalls--; luaC_checkGC(L); @@ -549,12 +541,9 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e status = LUA_ERRERR; } - if (FFlag::LuauActivateBeforeExec) - { - // since the call failed with an error, we might have to reset the 'active' thread state - if (!oldactive) - resetbit(L->stackstate, THREAD_ACTIVEBIT); - } + // since the call failed with an error, we might have to reset the 'active' thread state + if (!oldactive) + resetbit(L->stackstate, THREAD_ACTIVEBIT); if (FFlag::LuauCcallRestoreFix) { diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 7393fc74f..76ef7a06b 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -12,8 +12,6 @@ #include -LUAU_FASTFLAG(LuauArrayBoundary) - #define GC_SWEEPMAX 40 #define GC_SWEEPCOST 10 diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index f6f7a878f..c66de9c1d 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauArrayBoundary) - static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -38,10 +36,7 @@ static void validatetable(global_State* g, Table* h) { int sizenode = 1 << h->lsizenode; - if (FFlag::LuauArrayBoundary) - LUAU_ASSERT(h->lastfree <= sizenode); - else - LUAU_ASSERT(h->lastfree >= 0 && h->lastfree <= sizenode); + LUAU_ASSERT(h->lastfree <= sizenode); if (h->metatable) validateobjref(g, obj2gco(h), obj2gco(h->metatable)); diff --git a/VM/src/lnumprint.cpp b/VM/src/lnumprint.cpp new file mode 100644 index 000000000..2fd0f1bbd --- /dev/null +++ b/VM/src/lnumprint.cpp @@ -0,0 +1,375 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +// This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details +#include "luaconf.h" +#include "lnumutils.h" + +#include "lcommon.h" + +#include +#include // TODO: Remove with LuauSchubfach + +#ifdef _MSC_VER +#include +#endif + +// This work is based on: +// Raffaello Giulietti. The Schubfach way to render doubles. 2021 +// https://drive.google.com/file/d/1IEeATSVnEE6TkrHlCYNY2GjaraBjOT4f/edit + +// The code uses the notation from the paper for local variables where appropriate, and refers to paper sections/figures/results. + +LUAU_FASTFLAGVARIABLE(LuauSchubfach, false) + +// 9.8.2. Precomputed table for 128-bit overestimates of powers of 10 (see figure 3 for table bounds) +// To avoid storing 616 128-bit numbers directly we use a technique inspired by Dragonbox implementation and store 16 consecutive +// powers using a 128-bit baseline and a bitvector with 1-bit scale and 3-bit offset for the delta between each entry and base*5^k +static const int kPow10TableMin = -292; +static const int kPow10TableMax = 324; + +// clang-format off +static const uint64_t kPow5Table[16] = { + 0x8000000000000000, 0xa000000000000000, 0xc800000000000000, 0xfa00000000000000, 0x9c40000000000000, 0xc350000000000000, + 0xf424000000000000, 0x9896800000000000, 0xbebc200000000000, 0xee6b280000000000, 0x9502f90000000000, 0xba43b74000000000, + 0xe8d4a51000000000, 0x9184e72a00000000, 0xb5e620f480000000, 0xe35fa931a0000000, +}; +static const uint64_t kPow10Table[(kPow10TableMax - kPow10TableMin + 1 + 15) / 16][3] = { + {0xff77b1fcbebcdc4f, 0x25e8e89c13bb0f7b, 0x333443443333443b}, {0x8dd01fad907ffc3b, 0xae3da7d97f6792e4, 0xbbb3ab3cb3ba3cbc}, + {0x9d71ac8fada6c9b5, 0x6f773fc3603db4aa, 0x4ba4bc4bb4bb4bcc}, {0xaecc49914078536d, 0x58fae9f773886e19, 0x3ba3bc33b43b43bb}, + {0xc21094364dfb5636, 0x985915fc12f542e5, 0x33b43b43a33b33cb}, {0xd77485cb25823ac7, 0x7d633293366b828c, 0x34b44c444343443c}, + {0xef340a98172aace4, 0x86fb897116c87c35, 0x333343333343334b}, {0x84c8d4dfd2c63f3b, 0x29ecd9f40041e074, 0xccaccbbcbcbb4bbc}, + {0x936b9fcebb25c995, 0xcab10dd900beec35, 0x3ab3ab3ab3bb3bbb}, {0xa3ab66580d5fdaf5, 0xc13e60d0d2e0ebbb, 0x4cc3dc4db4db4dbb}, + {0xb5b5ada8aaff80b8, 0x0d819992132456bb, 0x33b33a34c33b34ab}, {0xc9bcff6034c13052, 0xfc89b393dd02f0b6, 0x33c33b44b43c34bc}, + {0xdff9772470297ebd, 0x59787e2b93bc56f8, 0x43b444444443434c}, {0xf8a95fcf88747d94, 0x75a44c6397ce912b, 0x443334343443343b}, + {0x8a08f0f8bf0f156b, 0x1b8e9ecb641b5900, 0xbbabab3aa3ab4ccc}, {0x993fe2c6d07b7fab, 0xe546a8038efe402a, 0x4cb4bc4db4db4bcc}, + {0xaa242499697392d2, 0xdde50bd1d5d0b9ea, 0x3ba3ba3bb33b33bc}, {0xbce5086492111aea, 0x88f4bb1ca6bcf585, 0x44b44c44c44c43cb}, + {0xd1b71758e219652b, 0xd3c36113404ea4a9, 0x44c44c44c444443b}, {0xe8d4a51000000000, 0x0000000000000000, 0x444444444444444c}, + {0x813f3978f8940984, 0x4000000000000000, 0xcccccccccccccccc}, {0x8f7e32ce7bea5c6f, 0xe4820023a2000000, 0xbba3bc4cc4cc4ccc}, + {0x9f4f2726179a2245, 0x01d762422c946591, 0x4aa3bb3aa3ba3bab}, {0xb0de65388cc8ada8, 0x3b25a55f43294bcc, 0x3ca33b33b44b43bc}, + {0xc45d1df942711d9a, 0x3ba5d0bd324f8395, 0x44c44c34c44b44cb}, {0xda01ee641a708de9, 0xe80e6f4820cc9496, 0x33b33b343333333c}, + {0xf209787bb47d6b84, 0xc0678c5dbd23a49b, 0x443444444443443b}, {0x865b86925b9bc5c2, 0x0b8a2392ba45a9b3, 0xdbccbcccb4cb3bbb}, + {0x952ab45cfa97a0b2, 0xdd945a747bf26184, 0x3bc4bb4ab3ca3cbc}, {0xa59bc234db398c25, 0x43fab9837e699096, 0x3bb3ac3ab3bb33ac}, + {0xb7dcbf5354e9bece, 0x0c11ed6d538aeb30, 0x33b43b43b34c34dc}, {0xcc20ce9bd35c78a5, 0x31ec038df7b441f5, 0x34c44c43c44b44cb}, + {0xe2a0b5dc971f303a, 0x2e44ae64840fd61e, 0x333333333333333c}, {0xfb9b7cd9a4a7443c, 0x169840ef017da3b2, 0x433344443333344c}, + {0x8bab8eefb6409c1a, 0x1ad089b6c2f7548f, 0xdcbdcc3cc4cc4bcb}, {0x9b10a4e5e9913128, 0xca7cf2b4191c8327, 0x3ab3cb3bc3bb4bbb}, + {0xac2820d9623bf429, 0x546345fa9fbdcd45, 0x3bb3cc43c43c43cb}, {0xbf21e44003acdd2c, 0xe0470a63e6bd56c4, 0x44b34a43b44c44bc}, + {0xd433179d9c8cb841, 0x5fa60692a46151ec, 0x43a33a33a333333c}, +}; +// clang-format on + +static const char kDigitTable[] = "0001020304050607080910111213141516171819202122232425262728293031323334353637383940414243444546474849" + "5051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899"; + +// x*y => 128-bit product (lo+hi) +inline uint64_t mul128(uint64_t x, uint64_t y, uint64_t* hi) +{ +#if defined(_MSC_VER) && defined(_M_X64) + return _umul128(x, y, hi); +#elif defined(__SIZEOF_INT128__) + unsigned __int128 r = x; + r *= y; + *hi = uint64_t(r >> 64); + return uint64_t(r); +#else + uint32_t x0 = uint32_t(x), x1 = uint32_t(x >> 32); + uint32_t y0 = uint32_t(y), y1 = uint32_t(y >> 32); + uint64_t p11 = uint64_t(x1) * y1, p01 = uint64_t(x0) * y1; + uint64_t p10 = uint64_t(x1) * y0, p00 = uint64_t(x0) * y0; + uint64_t mid = p10 + (p00 >> 32) + uint32_t(p01); + uint64_t r0 = (mid << 32) | uint32_t(p00); + uint64_t r1 = p11 + (mid >> 32) + (p01 >> 32); + *hi = r1; + return r0; +#endif +} + +// (x*y)>>64 => 128-bit product (lo+hi) +inline uint64_t mul192hi(uint64_t xhi, uint64_t xlo, uint64_t y, uint64_t* hi) +{ + uint64_t z2; + uint64_t z1 = mul128(xhi, y, &z2); + + uint64_t z1c; + uint64_t z0 = mul128(xlo, y, &z1c); + (void)z0; + + z1 += z1c; + z2 += (z1 < z1c); + + *hi = z2; + return z1; +} + +// 9.3. Rounding to odd (+ figure 8 + result 23) +inline uint64_t roundodd(uint64_t ghi, uint64_t glo, uint64_t cp) +{ + uint64_t xhi; + uint64_t xlo = mul128(glo, cp, &xhi); + (void)xlo; + + uint64_t yhi; + uint64_t ylo = mul128(ghi, cp, &yhi); + + uint64_t z = ylo + xhi; + return (yhi + (z < xhi)) | (z > 1); +} + +struct Decimal +{ + uint64_t s; + int k; +}; + +static Decimal schubfach(int exponent, uint64_t fraction) +{ + // Extract c & q such that c*2^q == |v| + uint64_t c = fraction; + int q = exponent - 1023 - 51; + + if (exponent != 0) // normal numbers have implicit leading 1 + { + c |= (1ull << 52); + q--; + } + + // 8.3. Fast path for integers + if (unsigned(-q) < 53 && (c & ((1ull << (-q)) - 1)) == 0) + return {c >> (-q), 0}; + + // 5. Rounding interval + int irr = (c == (1ull << 52) && q != -1074); // Qmin + int out = int(c & 1); + + // 9.8.1. Boundaries for c + uint64_t cbl = 4 * c - 2 + irr; + uint64_t cb = 4 * c; + uint64_t cbr = 4 * c + 2; + + // 9.1. Computing k and h + const int Q = 20; + const int C = 315652; // floor(2^Q * log10(2)) + const int A = -131008; // floor(2^Q * log10(3/4)) + const int C2 = 3483294; // floor(2^Q * log2(10)) + int k = (q * C + (irr ? A : 0)) >> Q; + int h = q + ((-k * C2) >> Q) + 1; // see (9) in 9.9 + + // 9.8.2. Overestimates of powers of 10 + // Recover 10^-k fraction using compact tables generated by tools/numutils.py + // The 128-bit fraction is encoded as 128-bit baseline * power-of-5 * scale + offset + LUAU_ASSERT(-k >= kPow10TableMin && -k <= kPow10TableMax); + int gtoff = -k - kPow10TableMin; + const uint64_t* gt = kPow10Table[gtoff >> 4]; + + uint64_t ghi; + uint64_t glo = mul192hi(gt[0], gt[1], kPow5Table[gtoff & 15], &ghi); + + // Apply 1-bit scale + 3-bit offset; note, offset is intentionally applied without carry, numutils.py validates that this is sufficient + int gterr = (gt[2] >> ((gtoff & 15) * 4)) & 15; + int gtscale = gterr >> 3; + + ghi <<= gtscale; + ghi += (glo >> 63) & gtscale; + glo <<= gtscale; + glo -= (gterr & 7) - 4; + + // 9.9. Boundaries for v + uint64_t vbl = roundodd(ghi, glo, cbl << h); + uint64_t vb = roundodd(ghi, glo, cb << h); + uint64_t vbr = roundodd(ghi, glo, cbr << h); + + // Main algorithm; see figure 7 + figure 9 + uint64_t s = vb / 4; + + if (s >= 10) + { + uint64_t sp = s / 10; + + bool upin = vbl + out <= 40 * sp; + bool wpin = vbr >= 40 * sp + 40 + out; + + if (upin != wpin) + return {sp + wpin, k + 1}; + } + + // Figure 7 contains the algorithm to select between u (s) and w (s+1) + // rup computes the last 4 conditions in that algorithm + // rup is only used when uin == win, but since these branches predict poorly we use branchless selects + bool uin = vbl + out <= 4 * s; + bool win = 4 * s + 4 + out <= vbr; + bool rup = vb >= 4 * s + 2 + 1 - (s & 1); + + return {s + (uin != win ? win : rup), k}; +} + +static char* printspecial(char* buf, int sign, uint64_t fraction) +{ + if (fraction == 0) + { + memcpy(buf, ("-inf") + (1 - sign), 4); + return buf + 3 + sign; + } + else + { + memcpy(buf, "nan", 4); + return buf + 3; + } +} + +static char* printunsignedrev(char* end, uint64_t num) +{ + while (num >= 10000) + { + unsigned int tail = unsigned(num % 10000); + + memcpy(end - 4, &kDigitTable[int(tail / 100) * 2], 2); + memcpy(end - 2, &kDigitTable[int(tail % 100) * 2], 2); + num /= 10000; + end -= 4; + } + + unsigned int rest = unsigned(num); + + while (rest >= 10) + { + memcpy(end - 2, &kDigitTable[int(rest % 100) * 2], 2); + rest /= 100; + end -= 2; + } + + if (rest) + { + end[-1] = '0' + int(rest); + end -= 1; + } + + return end; +} + +static char* printexp(char* buf, int num) +{ + *buf++ = 'e'; + *buf++ = num < 0 ? '-' : '+'; + + int v = num < 0 ? -num : num; + + if (v >= 100) + { + *buf++ = '0' + (v / 100); + v %= 100; + } + + memcpy(buf, &kDigitTable[v * 2], 2); + return buf + 2; +} + +inline char* trimzero(char* end) +{ + while (end[-1] == '0') + end--; + + return end; +} + +// We use fixed-length memcpy/memset since they lower to fast SIMD+scalar writes; the target buffers should have padding space +#define fastmemcpy(dst, src, size, sizefast) check_exp((size) <= sizefast, memcpy(dst, src, sizefast)) +#define fastmemset(dst, val, size, sizefast) check_exp((size) <= sizefast, memset(dst, val, sizefast)) + +char* luai_num2str(char* buf, double n) +{ + if (!FFlag::LuauSchubfach) + { + snprintf(buf, LUAI_MAXNUM2STR, LUA_NUMBER_FMT, n); + return buf + strlen(buf); + } + + // IEEE-754 + union + { + double v; + uint64_t bits; + } v = {n}; + int sign = int(v.bits >> 63); + int exponent = int(v.bits >> 52) & 2047; + uint64_t fraction = v.bits & ((1ull << 52) - 1); + + // specials + if (LUAU_UNLIKELY(exponent == 0x7ff)) + return printspecial(buf, sign, fraction); + + // sign bit + *buf = '-'; + buf += sign; + + // zero + if (exponent == 0 && fraction == 0) + { + buf[0] = '0'; + return buf + 1; + } + + // convert binary to decimal using Schubfach + Decimal d = schubfach(exponent, fraction); + LUAU_ASSERT(d.s < uint64_t(1e17)); + + // print the decimal to a temporary buffer; we'll need to insert the decimal point and figure out the format + char decbuf[40]; + char* decend = decbuf + 20; // significand needs at most 17 digits; the rest of the buffer may be copied using fixed length memcpy + char* dec = printunsignedrev(decend, d.s); + + int declen = int(decend - dec); + LUAU_ASSERT(declen <= 17); + + int dot = declen + d.k; + + // the limits are somewhat arbitrary but changing them may require changing fastmemset/fastmemcpy sizes below + if (dot >= -5 && dot <= 21) + { + // fixed point format + if (dot <= 0) + { + buf[0] = '0'; + buf[1] = '.'; + + fastmemset(buf + 2, '0', -dot, 5); + fastmemcpy(buf + 2 + (-dot), dec, declen, 17); + + return trimzero(buf + 2 + (-dot) + declen); + } + else if (dot == declen) + { + // no dot + fastmemcpy(buf, dec, dot, 17); + + return buf + dot; + } + else if (dot < declen) + { + // dot in the middle + fastmemcpy(buf, dec, dot, 16); + + buf[dot] = '.'; + + fastmemcpy(buf + dot + 1, dec + dot, declen - dot, 16); + + return trimzero(buf + declen + 1); + } + else + { + // no dot, zero padding + fastmemcpy(buf, dec, declen, 17); + fastmemset(buf + declen, '0', dot - declen, 8); + + return buf + dot; + } + } + else + { + // scientific format + buf[0] = dec[0]; + buf[1] = '.'; + fastmemcpy(buf + 2, dec + 1, declen - 1, 16); + + char* exp = trimzero(buf + declen + 1); + + return printexp(exp, dot - 1); + } +} diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index 67f832dc5..fba07bc3b 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -3,7 +3,6 @@ #pragma once #include -#include #define luai_numadd(a, b) ((a) + (b)) #define luai_numsub(a, b) ((a) - (b)) @@ -56,5 +55,9 @@ LUAU_FASTMATH_END #define luai_num2unsigned(i, n) ((i) = (unsigned)(long long)(n)) #endif -#define luai_num2str(s, n) snprintf((s), sizeof(s), LUA_NUMBER_FMT, (n)) +#define LUA_NUMBER_FMT "%.14g" /* TODO: Remove with LuauSchubfach */ +#define LUAI_MAXNUM2STR 48 + +LUAI_FUNC char* luai_num2str(char* buf, double n); + #define luai_str2num(s, p) strtod((s), (p)) diff --git a/VM/src/lobject.h b/VM/src/lobject.h index fd0a15b75..b642cf787 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -289,6 +289,7 @@ typedef struct Proto int sizek; int sizelineinfo; int linegaplog2; + int linedefined; uint8_t nups; /* number of upvalues */ diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 0b55fceac..83b59f3fa 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -24,8 +24,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauArrayBoundary, false) - // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -222,7 +220,7 @@ int luaH_next(lua_State* L, Table* t, StkId key) #define maybesetaboundary(t, boundary) \ { \ - if (FFlag::LuauArrayBoundary && t->aboundary <= 0) \ + if (t->aboundary <= 0) \ t->aboundary = -int(boundary); \ } @@ -705,7 +703,7 @@ int luaH_getn(Table* t) { int boundary = getaboundary(t); - if (FFlag::LuauArrayBoundary && boundary > 0) + if (boundary > 0) { if (!ttisnil(&t->array[t->sizearray - 1]) && t->node == dummynode) return t->sizearray; /* fast-path: the end of the array in `t' already refers to a boundary */ diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index add3588d4..cdb276c06 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,6 +13,9 @@ #include +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) + // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens template struct TempBuffer @@ -146,15 +149,19 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size uint8_t version = read(data, size, offset); // 0 means the rest of the bytecode is the error message - if (version == 0 || version != LBC_VERSION) + if (version == 0) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); + return 1; + } - if (version == 0) - lua_pushfstring(L, "%s%.*s", chunkid, int(size - offset), data + offset); - else - lua_pushfstring(L, "%s: bytecode version mismatch", chunkid); + if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : FFlag::LuauBytecodeV2Read ? (version != LBC_VERSION && version != LBC_VERSION_FUTURE) : (version != LBC_VERSION)) + { + char chunkid[LUA_IDSIZE]; + luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); + lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); return 1; } @@ -285,6 +292,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } + if (FFlag::LuauBytecodeV2Force || (FFlag::LuauBytecodeV2Read && version == LBC_VERSION_FUTURE)) + p->linedefined = readVarInt(data, size, offset); + else + p->linedefined = -1; + p->debugname = readString(strings, data, size, offset); uint8_t lineinfo = read(data, size, offset); @@ -307,11 +319,11 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->lineinfo[j] = lastoffset; } - int lastLine = 0; + int lastline = 0; for (int j = 0; j < intervals; ++j) { - lastLine += read(data, size, offset); - p->abslineinfo[j] = lastLine; + lastline += read(data, size, offset); + p->abslineinfo[j] = lastline; } } diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 5d802277a..31dd59c86 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -34,10 +34,11 @@ int luaV_tostring(lua_State* L, StkId obj) return 0; else { - char s[LUAI_MAXNUMBER2STR]; + char s[LUAI_MAXNUM2STR]; double n = nvalue(obj); - luai_num2str(s, n); - setsvalue2s(L, obj, luaS_new(L, s)); + char* e = luai_num2str(s, n); + LUAU_ASSERT(e < s + sizeof(s)); + setsvalue2s(L, obj, luaS_newlstr(L, s, e - s)); return 1; } } diff --git a/fuzz/number.cpp b/fuzz/number.cpp new file mode 100644 index 000000000..704474096 --- /dev/null +++ b/fuzz/number.cpp @@ -0,0 +1,35 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Luau/Common.h" + +#include +#include +#include +#include + +LUAU_FASTFLAG(LuauSchubfach); + +#define LUAI_MAXNUM2STR 48 + +char* luai_num2str(char* buf, double n); + +extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) +{ + if (Size < 8) + return 0; + + FFlag::LuauSchubfach.value = true; + + double num; + memcpy(&num, Data, 8); + + char buf[LUAI_MAXNUM2STR]; + char* end = luai_num2str(buf, num); + LUAU_ASSERT(end < buf + sizeof(buf)); + + *end = 0; + + double rec = strtod(buf, nullptr); + + LUAU_ASSERT(rec == num || (rec != rec && num != num)); + return 0; +} diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 2090b0148..41b553b55 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -83,8 +83,6 @@ TEST_SUITE_BEGIN("AstQuery"); TEST_CASE_FIXTURE(Fixture, "last_argument_function_call_type") { - ScopedFastFlag luauTailArgumentTypeInfo{"LuauTailArgumentTypeInfo", true}; - check(R"( local function foo() return 2 end local function bar(a: number) return -a end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 8ca09c0e0..210db7eea 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -1911,11 +1912,14 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") +// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_correct_function_no_parenthesis") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); + ACFixture fix; - check(R"( + fix.check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end @@ -1923,7 +1927,7 @@ local function bar2(a: string) return a .. 'x' end return target(b@1 )"); - auto ac = autocomplete('1'); + auto ac = fix.autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1976,11 +1980,14 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") +// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_correct_keywords") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); + ACFixture fix; - check(R"( + fix.check(R"( local function a(x: boolean) end local function b(x: number?) end local function c(x: (number) -> string) end @@ -1997,26 +2004,26 @@ local dc = d(f@4) local ec = e(f@5) )"); - auto ac = autocomplete('1'); + auto ac = fix.autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('2'); + ac = fix.autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('3'); + ac = fix.autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('4'); + ac = fix.autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = autocomplete('5'); + ac = fix.autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -2507,21 +2514,23 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE_FIXTURE(UnfrozenFixture, "autocomplete_documentation_symbols") +TEST_CASE("autocomplete_documentation_symbols") { - loadDefinition(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + fix.loadDefinition(R"( declare y: { x: number, } )"); - fileResolver.source["Module/A"] = R"( + fix.fileResolver.source["Module/A"] = R"( local a = y. )"; - frontend.check("Module/A"); + fix.frontend.check("Module/A"); - auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete(fix.frontend, "Module/A", Position{1, 21}, nullCallback); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b055a38e4..663b329ee 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -13,8 +13,11 @@ #include "ScopedFlags.h" #include +#include #include +extern bool verbose; + static int lua_collectgarbage(lua_State* L) { static const char* const opts[] = {"stop", "restart", "collect", "count", "isrunning", "step", "setgoal", "setstepmul", "setstepsize", nullptr}; @@ -146,15 +149,21 @@ static StateRef runConformance(const char* name, void (*setup)(lua_State* L) = n luaL_openlibs(L); // Register a few global functions for conformance tests - static const luaL_Reg funcs[] = { + std::vector funcs = { {"collectgarbage", lua_collectgarbage}, {"loadstring", lua_loadstring}, - {"print", lua_silence}, // Disable print() by default; comment this out to enable debug prints in tests - {nullptr, nullptr}, }; + if (!verbose) + { + funcs.push_back({"print", lua_silence}); + } + + // "null" terminate the list of functions to register + funcs.push_back({nullptr, nullptr}); + lua_pushvalue(L, LUA_GLOBALSINDEX); - luaL_register(L, nullptr, funcs); + luaL_register(L, nullptr, funcs.data()); lua_pop(L, 1); // In some configurations we have a larger C stack consumption which trips some conformance tests @@ -312,8 +321,6 @@ TEST_CASE("GC") TEST_CASE("Bitwise") { - ScopedFastFlag sff("LuauBit32Count", true); - runConformance("bitwise.lua"); } @@ -491,6 +498,9 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { + ScopedFastFlag sffr("LuauBytecodeV2Read", true); + ScopedFastFlag sffw("LuauBytecodeV2Write", true); + runConformance("debug.lua"); } @@ -738,8 +748,6 @@ TEST_CASE("ApiFunctionCalls") // lua_equal with a sleeping thread wake up { - ScopedFastFlag luauActivateBeforeExec("LuauActivateBeforeExec", true); - lua_State* L2 = lua_newthread(L); lua_getfield(L2, LUA_GLOBALSINDEX, "create_with_tm"); @@ -913,4 +921,11 @@ TEST_CASE("Coverage") nullptr, nullptr, &copts); } +TEST_CASE("StringConversion") +{ + ScopedFastFlag sff{"LuauSchubfach", true}; + + runConformance("strconv.lua"); +} + TEST_SUITE_END(); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index 36d6f5612..ca4281a0b 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -103,11 +103,6 @@ Fixture::~Fixture() Luau::resetPrintLine(); } -UnfrozenFixture::UnfrozenFixture() - : Fixture(false) -{ -} - AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& parseOptions) { sourceModule.reset(new SourceModule); diff --git a/tests/Fixture.h b/tests/Fixture.h index de2b7381e..e01632eab 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -152,15 +152,6 @@ struct Fixture LoadDefinitionFileResult loadDefinition(const std::string& source); }; -// Disables arena freezing for a given test case. -// Do not use this in new tests. If you are running into access violations, you -// are violating Luau's memory model - the fix is not to use UnfrozenFixture. -// Related: CLI-45692 -struct UnfrozenFixture : Fixture -{ - UnfrozenFixture(); -}; - ModuleName fromString(std::string_view name); template diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 51fcd3d6c..405f26e07 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -914,6 +914,8 @@ TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") TEST_CASE_FIXTURE(FrontendFixture, "imported_table_modification_2") { + ScopedFastFlag sffs("LuauSealExports", true); + frontend.options.retainFullTypeGraphs = false; fileResolver.source["Module/A"] = R"( @@ -927,7 +929,7 @@ return a; --!nonstrict local a = require(script.Parent.A) local b = {} -function a:b() end -- this should error, but doesn't +function a:b() end -- this should error, since A doesn't define a:b() return b )"; @@ -942,8 +944,7 @@ a:b() -- this should error, since A doesn't define a:b() LUAU_REQUIRE_NO_ERRORS(resultA); CheckResult resultB = frontend.check("Module/B"); - // TODO (CLI-45592): this should error, since we shouldn't be adding properties to objects from other modules - LUAU_REQUIRE_NO_ERRORS(resultB); + LUAU_REQUIRE_ERRORS(resultB); CheckResult resultC = frontend.check("Module/C"); LUAU_REQUIRE_ERRORS(resultC); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 71ff4e1b8..275782b30 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -620,7 +620,7 @@ struct AssertionCatcher { tripped = 0; oldhook = Luau::assertHandler(); - Luau::assertHandler() = [](const char* expr, const char* file, int line) -> int { + Luau::assertHandler() = [](const char* expr, const char* file, int line, const char* function) -> int { ++tripped; return 0; }; diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 506279b94..5e08654a7 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -11,6 +11,8 @@ LUAU_FASTFLAG(LuauFixTonumberReturnType) using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -444,19 +446,28 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("thread_is_a_type") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local co = coroutine.create(function() end) )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.threadType, *requireType("co")); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); + CHECK_EQ(*fix.typeChecker.threadType, *fix.requireType("co")); } -TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("coroutine_resume_anything_goes") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local function nifty(x, y) print(x, y) local z = coroutine.yield(1, 2) @@ -469,12 +480,17 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") local answer = coroutine.resume(co, 3) )"); - LUAU_REQUIRE_NO_ERRORS(result); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); } -TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("coroutine_wrap_anything_goes") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( --!nonstrict local function nifty(x, y) print(x, y) @@ -488,7 +504,8 @@ TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") local answer = f(3) )"); - LUAU_REQUIRE_NO_ERRORS(result); + // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() == 0); } TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index b62044fa9..114679e34 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -629,8 +629,6 @@ return exports TEST_CASE_FIXTURE(Fixture, "instantiated_function_argument_names") { - ScopedFastFlag luauFunctionArgumentNameSize{"LuauFunctionArgumentNameSize", true}; - CheckResult result = check(R"( local function f(a: T, ...: U...) end diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 76324556a..f70f3b1c8 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4803,8 +4803,6 @@ local bar = foo.nutrition + 100 TEST_CASE_FIXTURE(Fixture, "require_failed_module") { - ScopedFastFlag luauModuleRequireErrorPack{"LuauModuleRequireErrorPack", true}; - fileResolver.source["game/A"] = R"( return unfortunately() )"; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index f55b46a40..1e790eba9 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -12,6 +12,8 @@ LUAU_FASTFLAG(LuauQuantifyInPlace2); using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + struct TryUnifyFixture : Fixture { TypeArena arena; @@ -28,7 +30,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "primitives_unify") TypeVar numberOne{TypeVariant{PrimitiveTypeVar{PrimitiveTypeVar::Number}}}; TypeVar numberTwo = numberOne; - state.tryUnify(&numberOne, &numberTwo); + state.tryUnify(&numberTwo, &numberOne); CHECK(state.errors.empty()); } @@ -41,9 +43,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") TypeVar functionTwo{TypeVariant{ FunctionTypeVar(arena.addTypePack({arena.freshType(globalScope->level)}), arena.addTypePack({arena.freshType(globalScope->level)}))}}; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(state.errors.empty()); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + CHECK_EQ(functionOne, functionTwo); } @@ -61,7 +66,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_functions_are_preserved") TypeVar functionTwoSaved = functionTwo; - state.tryUnify(&functionOne, &functionTwo); + state.tryUnify(&functionTwo, &functionOne); CHECK(!state.errors.empty()); CHECK_EQ(functionOne, functionOneSaved); @@ -80,10 +85,13 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK(state.errors.empty()); + if (FFlag::LuauUseCommittingTxnLog) + state.log.commit(); + CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -101,11 +109,12 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); - state.tryUnify(&tableOne, &tableTwo); + state.tryUnify(&tableTwo, &tableOne); CHECK_EQ(1, state.errors.size()); - state.log.rollback(); + if (!FFlag::LuauUseCommittingTxnLog) + state.DEPRECATED_log.rollback(); CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -170,7 +179,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") TypePackVar testPack{TypePack{{typeChecker.numberType, typeChecker.stringType}, std::nullopt}}; TypePackVar variadicPack{VariadicTypePack{typeChecker.numberType}}; - state.tryUnify(&variadicPack, &testPack); + state.tryUnify(&testPack, &variadicPack); CHECK(!state.errors.empty()); } @@ -180,7 +189,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_tails_respect_progress") TypePackVar a{TypePack{{typeChecker.numberType, typeChecker.stringType, typeChecker.booleanType, typeChecker.booleanType}}}; TypePackVar b{TypePack{{typeChecker.numberType, typeChecker.stringType}, &variadicPack}}; - state.tryUnify(&a, &b); + state.tryUnify(&b, &a); CHECK(state.errors.empty()); } @@ -214,32 +223,41 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } -TEST_CASE_FIXTURE(TryUnifyFixture, "undo_new_prop_on_unsealed_table") +TEST_CASE("undo_new_prop_on_unsealed_table") { ScopedFastFlag flags[] = { {"LuauTableSubtypingVariance2", true}, + // This test makes no sense with a committing TxnLog. + {"LuauUseCommittingTxnLog", false}, }; // I am not sure how to make this happen in Luau code. - TypeId unsealedTable = arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); - TypeId sealedTable = arena.addType(TableTypeVar{ - {{"prop", Property{getSingletonTypes().numberType}}}, - std::nullopt, - TypeLevel{}, - TableState::Sealed - }); + TryUnifyFixture fix; + + TypeId unsealedTable = fix.arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); + TypeId sealedTable = + fix.arena.addType(TableTypeVar{{{"prop", Property{getSingletonTypes().numberType}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); const TableTypeVar* ttv = get(unsealedTable); REQUIRE(ttv); - state.tryUnify(unsealedTable, sealedTable); + fix.state.tryUnify(sealedTable, unsealedTable); // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. CHECK(!ttv->props.empty()); - state.log.rollback(); + fix.state.DEPRECATED_log.rollback(); CHECK(ttv->props.empty()); } +TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") +{ + TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); + TypePackId numberAndFreeTail = arena.addTypePack(TypePack{{typeChecker.numberType}, arena.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})}); + + ErrorVec unifyErrors = state.canUnify(numberAndFreeTail, threeNumbers); + CHECK(unifyErrors.size() == 0); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 3f4420cda..5d37b032a 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -10,6 +10,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauUseCommittingTxnLog) + TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -263,10 +265,13 @@ TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); } -// CLI-45791 -TEST_CASE_FIXTURE(UnfrozenFixture, "type_pack_hidden_free_tail_infinite_growth") +// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the +// LuauUseCommittingTxnLog flag. +TEST_CASE("type_pack_hidden_free_tail_infinite_growth") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( --!nonstrict if _ then _[function(l0)end],l0 = _ @@ -278,7 +283,8 @@ elseif _ then end )"); - LUAU_REQUIRE_ERRORS(result); + // Switch back to LUAU_REQUIRE_ERRORS(result) when using TEST_CASE_FIXTURE. + CHECK(result.errors.size() > 0); } TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 2357869e9..b54ba9962 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -8,6 +8,7 @@ #include "doctest.h" LUAU_FASTFLAG(LuauEqConstraint) +LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -282,16 +283,19 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE_FIXTURE(UnfrozenFixture, "optional_union_follow") +TEST_CASE("optional_union_follow") { - CheckResult result = check(R"( + Fixture fix(FFlag::LuauUseCommittingTxnLog); + + CheckResult result = fix.check(R"( local y: number? = 2 local x = y local function f(a: number, b: typeof(x), c: typeof(x)) return -a end return f() )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE_EQ(result.errors.size(), 1); + // LUAU_REQUIRE_ERROR_COUNT(1, result); auto acm = get(result.errors[0]); REQUIRE(acm); diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 13db923ed..2e0d149ec 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -185,6 +185,8 @@ TEST_CASE_FIXTURE(Fixture, "UnionTypeVarIterator_with_empty_union") TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") { + ScopedFastFlag sff{"LuauSealExports", true}; + TypeVar ftv11{FreeTypeVar{TypeLevel{}}}; TypePackVar tp24{TypePack{{&ftv11}}}; @@ -261,7 +263,7 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TypeId result = typeChecker.anyify(typeChecker.globalScope, root, Location{}); - CHECK_EQ("{ f: t1 } where t1 = () -> { f: () -> { f: ({ f: t1 }) -> (), signal: { f: (any) -> () } } }", toString(result)); + CHECK_EQ("{| f: t1 |} where t1 = () -> {| f: () -> {| f: ({| f: t1 |}) -> (), signal: {| f: (any) -> () |} |} |}", toString(result)); } TEST_CASE("tagging_tables") diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 9cf3c7423..8c96ab335 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -98,4 +98,13 @@ assert(quuz(function(...) end) == "0 true") assert(quuz(function(a, b) end) == "2 false") assert(quuz(function(a, b, ...) end) == "2 true") +-- info linedefined & line +function testlinedefined() + local line = debug.info(1, "l") + local linedefined = debug.info(testlinedefined, "l") + assert(linedefined + 1 == line) +end + +testlinedefined() + return 'OK' diff --git a/tests/conformance/strconv.lua b/tests/conformance/strconv.lua new file mode 100644 index 000000000..85ad0295e --- /dev/null +++ b/tests/conformance/strconv.lua @@ -0,0 +1,51 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +-- This file is based on Lua 5.x tests -- https://github.com/lua/lua/tree/master/testes +print("testing string-number conversion") + +-- zero +assert(tostring(0) == "0") +assert(tostring(0/-1) == "-0") + +-- specials +assert(tostring(1/0) == "inf") +assert(tostring(-1/0) == "-inf") +assert(tostring(0/0) == "nan") + +-- integers +assert(tostring(1) == "1") +assert(tostring(42) == "42") +assert(tostring(-4294967296) == "-4294967296") +assert(tostring(9007199254740991) == "9007199254740991") + +-- decimals +assert(tostring(0.5) == "0.5") +assert(tostring(0.1) == "0.1") +assert(tostring(-0.17) == "-0.17") +assert(tostring(math.pi) == "3.141592653589793") + +-- fuzzing corpus +assert(tostring(5.4536123983019448e-311) == "5.453612398302e-311") +assert(tostring(5.4834368411298348e-311) == "5.48343684113e-311") +assert(tostring(4.4154895841930002e-305) == "4.415489584193e-305") +assert(tostring(1125968630513728) == "1125968630513728") +assert(tostring(3.3951932655938423e-313) == "3.3951932656e-313") +assert(tostring(1.625) == "1.625") +assert(tostring(4.9406564584124654e-324) == "5.e-324") +assert(tostring(2.0049288280105384) == "2.0049288280105384") +assert(tostring(3.0517578125e-05) == "0.000030517578125") +assert(tostring(1.383544921875) == "1.383544921875") +assert(tostring(3.0053350932691001) == "3.0053350932691") +assert(tostring(0.0001373291015625) == "0.0001373291015625") +assert(tostring(-1.9490628022799998e+289) == "-1.94906280228e+289") +assert(tostring(-0.00610404721867928) == "-0.00610404721867928") +assert(tostring(0.00014495849609375) == "0.00014495849609375") +assert(tostring(0.453125) == "0.453125") +assert(tostring(-4.2375343999999997e+73) == "-4.2375344e+73") +assert(tostring(1.3202313930270133e-192) == "1.3202313930270133e-192") +assert(tostring(3.6984408976312836e+19) == "36984408976312840000") +assert(tostring(2.0563000527063302) == "2.05630005270633") +assert(tostring(4.8970527433648997e-260) == "4.8970527433649e-260") +assert(tostring(1.62890625) == "1.62890625") +assert(tostring(1.1295093211933533e+65) == "1.1295093211933533e+65") + +return "OK" diff --git a/tests/main.cpp b/tests/main.cpp index ed17070c6..cd24e100f 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -2,6 +2,8 @@ #include "Luau/Common.h" #define DOCTEST_CONFIG_IMPLEMENT +// Our calls to parseOption/parseFlag don't provide a prefix so set the prefix to the empty string. +#define DOCTEST_CONFIG_OPTIONS_PREFIX "" #include "doctest.h" #ifdef _WIN32 @@ -18,6 +20,10 @@ #include +// Indicates if verbose output is enabled. +// Currently, this enables output from lua's 'print', but other verbose output could be enabled eventually. +bool verbose = false; + static bool skipFastFlag(const char* flagName) { if (strncmp(flagName, "Test", 4) == 0) @@ -46,7 +52,7 @@ static bool debuggerPresent() #endif } -static int assertionHandler(const char* expr, const char* file, int line) +static int assertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) LUAU_DEBUGBREAK(); @@ -235,6 +241,11 @@ int main(int argc, char** argv) return 0; } + if (doctest::parseFlag(argc, argv, "--verbose")) + { + verbose = true; + } + if (std::vector flags; doctest::parseCommaSepArgs(argc, argv, "--fflags=", flags)) setFastFlags(flags); @@ -261,7 +272,15 @@ int main(int argc, char** argv) } } - return context.run(); + int result = context.run(); + if (doctest::parseFlag(argc, argv, "--help") || doctest::parseFlag(argc, argv, "-h")) + { + printf("Additional command line options:\n"); + printf(" --verbose Enables verbose output (e.g. lua 'print' statements)\n"); + printf(" --fflags= Sets specified fast flags\n"); + printf(" --list-fflags List all fast flags\n"); + } + return result; } diff --git a/tools/numprint.py b/tools/numprint.py new file mode 100644 index 000000000..47ad36d9c --- /dev/null +++ b/tools/numprint.py @@ -0,0 +1,82 @@ +#!/usr/bin/python +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# This code can be used to generate power tables for Schubfach algorithm (see lnumprint.cpp) + +import math +import sys + +(_, pow10min, pow10max, compact) = sys.argv +pow10min = int(pow10min) +pow10max = int(pow10max) +compact = compact == "True" + +# extract high 128 bits of the value +def high128(n, roundup): + L = math.ceil(math.log2(n)) + + r = 0 + for i in range(L - 128, L): + if i >= 0 and (n & (1 << i)) != 0: + r |= (1 << (i - L + 128)) + + return r + (1 if roundup else 0) + +def pow10approx(n): + if n == 0: + return 1 << 127 + elif n > 0: + return high128(10**n, 5**n >= 2**128) + else: + # 10^-n is a binary fraction that can't be represented in floating point + # we need to extract top 128 bits of the fraction starting from the first 1 + # to get there, we need to divide 2^k by 10^n for a sufficiently large k and repeat the extraction process + p = 10**-n + k = 2**128 * 16**-n # this guarantees that the fraction has more than 128 extra bits + return high128(k // p, True) + +def pow5_64(n): + assert(n >= 0) + if n == 0: + return 1 << 63 + else: + return high128(5**n, False) >> 64 + +if not compact: + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1): + h = hex(pow10approx(p))[2:] + assert(len(h) == 32) + print(" {0x%s, 0x%s}," % (h[0:16].upper(), h[16:32].upper())) + print("}") +else: + print("// kPow5Table") + print("{") + for i in range(16): + print(" " + hex(pow5_64(i)) + ",") + print("}") + print("// kPow10Table", pow10min, "..", pow10max) + print("{") + for p in range(pow10min, pow10max + 1, 16): + base = pow10approx(p) + errw = 0 + for i in range(16): + real = pow10approx(p + i) + appr = (base * pow5_64(i)) >> 64 + scale = 1 if appr < (1 << 127) else 0 # 1-bit scale + + offset = (appr << scale) - real + assert(offset >= -4 and offset <= 3) # 3-bit offset + assert((appr << scale) >> 64 == real >> 64) # offset only affects low half + assert((appr << scale) - offset == real) # validate full reconstruction + + err = (scale << 3) | (offset + 4) + errw |= err << (i * 4) + + hbase = hex(base)[2:] + assert(len(hbase) == 32) + assert(errw < 1 << 64) + + print(" {0x%s, 0x%s, 0x%16x}," % (hbase[0:16], hbase[16:32], errw)) + print("}") From d189bd9b1a3c6ecc882c793c7382a1c886635900 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 6 Jan 2022 17:37:50 -0800 Subject: [PATCH 13/32] Enable V2Read flag early --- VM/src/lvmload.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index cdb276c06..7839c68c2 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) +LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, true) LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens From 80d5c0000ee34767f8fec7ec24c02a438174a667 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 14 Jan 2022 08:06:31 -0800 Subject: [PATCH 14/32] Sync to upstream/release/510 --- Analysis/include/Luau/TypeInfer.h | 12 +- Analysis/include/Luau/TypeVar.h | 20 +- Analysis/src/Autocomplete.cpp | 79 +-- Analysis/src/Error.cpp | 12 +- Analysis/src/IostreamHelpers.cpp | 8 +- Analysis/src/JsonEncoder.cpp | 38 ++ Analysis/src/Module.cpp | 26 +- Analysis/src/ToString.cpp | 107 +++- Analysis/src/Transpiler.cpp | 65 +- Analysis/src/TypeAttach.cpp | 12 +- Analysis/src/TypeInfer.cpp | 310 ++++++++-- Analysis/src/TypeVar.cpp | 4 + Analysis/src/Unifier.cpp | 17 +- Ast/include/Luau/Ast.h | 49 +- Ast/include/Luau/Parser.h | 6 +- Ast/src/Ast.cpp | 32 +- Ast/src/Parser.cpp | 128 ++-- CLI/Analyze.cpp | 18 +- CLI/Repl.cpp | 131 ++-- Compiler/src/Builtins.cpp | 197 ++++++ Compiler/src/Builtins.h | 41 ++ Compiler/src/BytecodeBuilder.cpp | 6 +- Compiler/src/Compiler.cpp | 894 ++------------------------- Compiler/src/ConstantFolding.cpp | 394 ++++++++++++ Compiler/src/ConstantFolding.h | 48 ++ Compiler/src/TableShape.cpp | 129 ++++ Compiler/src/TableShape.h | 21 + Compiler/src/ValueTracking.cpp | 103 +++ Compiler/src/ValueTracking.h | 42 ++ Sources.cmake | 8 + VM/src/ldo.cpp | 12 +- VM/src/lfunc.cpp | 8 +- VM/src/lstrlib.cpp | 20 +- bench/tests/sunspider/3d-cube.lua | 32 +- tests/Autocomplete.test.cpp | 61 +- tests/Compiler.test.cpp | 93 +-- tests/Conformance.test.cpp | 8 - tests/Fixture.cpp | 5 +- tests/Fixture.h | 2 +- tests/Parser.test.cpp | 53 +- tests/ToString.test.cpp | 22 +- tests/Transpiler.test.cpp | 14 +- tests/TypeInfer.annotations.test.cpp | 2 +- tests/TypeInfer.refinements.test.cpp | 28 +- tests/TypeInfer.singletons.test.cpp | 26 +- tests/TypeInfer.tables.test.cpp | 101 ++- tests/TypeInfer.test.cpp | 150 ++++- tests/TypeInfer.typePacks.cpp | 324 ++++++++++ 48 files changed, 2649 insertions(+), 1269 deletions(-) create mode 100644 Compiler/src/Builtins.cpp create mode 100644 Compiler/src/Builtins.h create mode 100644 Compiler/src/ConstantFolding.cpp create mode 100644 Compiler/src/ConstantFolding.h create mode 100644 Compiler/src/TableShape.cpp create mode 100644 Compiler/src/TableShape.h create mode 100644 Compiler/src/ValueTracking.cpp create mode 100644 Compiler/src/ValueTracking.h diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 312283b05..aa0900140 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -97,6 +97,12 @@ struct ApplyTypeFunction : Substitution TypePackId clean(TypePackId tp) override; }; +struct GenericTypeDefinitions +{ + std::vector genericTypes; + std::vector genericPacks; +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -146,7 +152,7 @@ struct TypeChecker ExprResult checkExpr(const ScopePtr& scope, const AstExprBinary& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprTypeAssertion& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprError& expr); - ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr); + ExprResult checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType = std::nullopt); TypeId checkExprTable(const ScopePtr& scope, const AstExprTable& expr, const std::vector>& fieldTypes, std::optional expectedType); @@ -336,8 +342,8 @@ struct TypeChecker const std::vector& typePackParams, const Location& location); // Note: `scope` must be a fresh scope. - std::pair, std::vector> createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames); + GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index d6e177142..fd2c2afa7 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -181,6 +181,18 @@ const T* get(const SingletonTypeVar* stv) return nullptr; } +struct GenericTypeDefinition +{ + TypeId ty; + std::optional defaultValue; +}; + +struct GenericTypePackDefinition +{ + TypePackId tp; + std::optional defaultValue; +}; + struct FunctionArgument { Name name; @@ -358,8 +370,8 @@ struct ClassTypeVar struct TypeFun { // These should all be generic - std::vector typeParams; - std::vector typePackParams; + std::vector typeParams; + std::vector typePackParams; /** The underlying type. * @@ -369,13 +381,13 @@ struct TypeFun TypeId type; TypeFun() = default; - TypeFun(std::vector typeParams, TypeId type) + TypeFun(std::vector typeParams, TypeId type) : typeParams(std::move(typeParams)) , type(type) { } - TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) + TypeFun(std::vector typeParams, std::vector typePackParams, TypeId type) : typeParams(std::move(typeParams)) , typePackParams(std::move(typePackParams)) , type(type) diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 67ebd0755..7a801f970 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,11 +13,10 @@ #include LUAU_FASTFLAG(LuauUseCommittingTxnLog) -LUAU_FASTFLAG(LuauIfElseExpressionAnalysisSupport) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompletePreferToCallFunctions, false); LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); +LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -291,51 +290,23 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ expectedType = follow(*it); } - if (FFlag::LuauAutocompletePreferToCallFunctions) + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) { - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) - { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - } - } - - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - } - else - { - if (canUnify(ty, expectedType)) - return TypeCorrectKind::Correct; - - // We also want to suggest functions that return compatible result - const FunctionTypeVar* ftv = get(ty); - - if (!ftv) - return TypeCorrectKind::None; - auto [retHead, retTail] = flatten(ftv->retType); - if (!retHead.empty()) - return canUnify(retHead.front(), expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return TypeCorrectKind::CorrectFunctionResult; // We might only have a variadic tail pack, check if the element is compatible if (retTail) { - if (const VariadicTypePack* vtp = get(follow(*retTail))) - return canUnify(vtp->ty, expectedType) ? TypeCorrectKind::CorrectFunctionResult : TypeCorrectKind::None; + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return TypeCorrectKind::CorrectFunctionResult; } - - return TypeCorrectKind::None; } + + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } enum class PropIndexType @@ -435,13 +406,28 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - if (get(indexIt->second.type) || get(indexIt->second.type)) - autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); - else if (auto indexFunction = get(indexIt->second.type)) + if (FFlag::LuauMissingFollowACMetatables) { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } + } + else + { + if (get(indexIt->second.type) || get(indexIt->second.type)) + autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); + else if (auto indexFunction = get(indexIt->second.type)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + } } } } @@ -1224,7 +1210,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul if (auto it = module.astTypes.find(node->asExpr())) autocompleteProps(module, typeArena, *it, PropIndexType::Point, ancestry, result); } - else if (FFlag::LuauIfElseExpressionAnalysisSupport && autocompleteIfElseExpression(node, ancestry, position, result)) + else if (autocompleteIfElseExpression(node, ancestry, position, result)) return; else if (node->is()) return; @@ -1261,8 +1247,7 @@ static void autocompleteExpression(const SourceModule& sourceModule, const Modul TypeCorrectKind correctForFunction = functionIsExpectedAt(module, node, position).value_or(false) ? TypeCorrectKind::Correct : TypeCorrectKind::None; - if (FFlag::LuauIfElseExpressionAnalysisSupport) - result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; + result["if"] = {AutocompleteEntryKind::Keyword, std::nullopt, false, false}; result["true"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["false"] = {AutocompleteEntryKind::Keyword, typeChecker.booleanType, false, false, correctForBoolean}; result["nil"] = {AutocompleteEntryKind::Keyword, typeChecker.nilType, false, false, correctForNil}; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index ce832c6b3..88069f1f5 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -190,24 +190,24 @@ struct ErrorConverter { name += "<"; bool first = true; - for (TypeId t : e.typeFun.typeParams) + for (auto param : e.typeFun.typeParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.ty); } - for (TypePackId t : e.typeFun.typePackParams) + for (auto param : e.typeFun.typePackParams) { if (first) first = false; else name += ", "; - name += toString(t); + name += toString(param.tp); } name += ">"; @@ -544,13 +544,13 @@ bool IncorrectGenericParameterCount::operator==(const IncorrectGenericParameterC for (size_t i = 0; i < typeFun.typeParams.size(); ++i) { - if (typeFun.typeParams[i] != rhs.typeFun.typeParams[i]) + if (typeFun.typeParams[i].ty != rhs.typeFun.typeParams[i].ty) return false; } for (size_t i = 0; i < typeFun.typePackParams.size(); ++i) { - if (typeFun.typePackParams[i] != rhs.typeFun.typePackParams[i]) + if (typeFun.typePackParams[i].tp != rhs.typeFun.typePackParams[i].tp) return false; } diff --git a/Analysis/src/IostreamHelpers.cpp b/Analysis/src/IostreamHelpers.cpp index 5bc76ade5..19c2ddabd 100644 --- a/Analysis/src/IostreamHelpers.cpp +++ b/Analysis/src/IostreamHelpers.cpp @@ -96,24 +96,24 @@ std::ostream& operator<<(std::ostream& stream, const IncorrectGenericParameterCo { stream << "<"; bool first = true; - for (TypeId t : error.typeFun.typeParams) + for (auto param : error.typeFun.typeParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.ty); } - for (TypePackId t : error.typeFun.typePackParams) + for (auto param : error.typeFun.typePackParams) { if (first) first = false; else stream << ", "; - stream << toString(t); + stream << toString(param.tp); } stream << ">"; diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 23491a5a1..8dd597e17 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,6 +5,8 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace Luau { @@ -337,6 +339,42 @@ struct AstJsonEncoder : public AstVisitor writeRaw("}"); } + void write(const AstGenericType& genericType) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericType.name); + if (genericType.defaultValue) + write("type", genericType.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericType.name); + } + } + + void write(const AstGenericTypePack& genericTypePack) + { + if (FFlag::LuauTypeAliasDefaults) + { + writeRaw("{"); + bool c = pushComma(); + write("name", genericTypePack.name); + if (genericTypePack.defaultValue) + write("type", genericTypePack.defaultValue); + popComma(c); + writeRaw("}"); + } + else + { + write(genericTypePack.name); + } + } + void write(AstExprTable::Item::Kind kind) { switch (kind) diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index cff85897c..9f352f4b3 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAG(LuauTypeAliasDefaults) namespace Luau { @@ -447,11 +448,28 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks& seenTypePacks, CloneState& cloneState) { TypeFun result; - for (TypeId ty : typeFun.typeParams) - result.typeParams.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - for (TypePackId tp : typeFun.typePackParams) - result.typePackParams.push_back(clone(tp, dest, seenTypes, seenTypePacks, cloneState)); + for (auto param : typeFun.typeParams) + { + TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typeParams.push_back({ty, defaultValue}); + } + + for (auto param : typeFun.typePackParams) + { + TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); + + result.typePackParams.push_back({tp, defaultValue}); + } result.type = clone(typeFun.type, dest, seenTypes, seenTypePacks, cloneState); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 889dd6dc5..4b898d3a6 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -11,6 +11,7 @@ #include LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) +LUAU_FASTFLAG(LuauTypeAliasDefaults) /* * Prefix generic typenames with gen- @@ -209,6 +210,14 @@ struct StringifierState result.name += s; } + + void emit(const char* s) + { + if (opts.maxTypeLength > 0 && result.name.length() > opts.maxTypeLength) + return; + + result.name += s; + } }; struct TypeVarStringifier @@ -280,13 +289,28 @@ struct TypeVarStringifier else first = false; - if (!singleTp) - state.emit("("); + if (FFlag::LuauTypeAliasDefaults) + { + bool wrap = !singleTp && get(follow(tp)); - stringify(tp); + if (wrap) + state.emit("("); - if (!singleTp) - state.emit(")"); + stringify(tp); + + if (wrap) + state.emit(")"); + } + else + { + if (!singleTp) + state.emit("("); + + stringify(tp); + + if (!singleTp) + state.emit(")"); + } } if (types.size() || typePacks.size()) @@ -1086,7 +1110,7 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +std::string toStringNamedFunction_DEPRECATED(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) { std::string s = prefix; @@ -1175,6 +1199,77 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV return s; } +std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +{ + if (!FFlag::LuauTypeAliasDefaults) + return toStringNamedFunction_DEPRECATED(prefix, ftv, opts); + + ToStringResult result; + StringifierState state(opts, result, opts.nameMap); + TypeVarStringifier tvs{state}; + + state.emit(prefix); + + if (!opts.hideNamedFunctionTypeParameters) + tvs.stringify(ftv.generics, ftv.genericPacks); + + state.emit("("); + + auto argPackIter = begin(ftv.argTypes); + auto argNameIter = ftv.argNames.begin(); + + bool first = true; + while (argPackIter != end(ftv.argTypes)) + { + if (!first) + state.emit(", "); + first = false; + + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); + ++argNameIter; + } + else + { + state.emit("_: "); + } + + tvs.stringify(*argPackIter); + ++argPackIter; + } + + if (argPackIter.tail()) + { + if (!first) + state.emit(", "); + + state.emit("...: "); + + if (auto vtp = get(*argPackIter.tail())) + tvs.stringify(vtp->ty); + else + tvs.stringify(*argPackIter.tail()); + } + + state.emit("): "); + + size_t retSize = size(ftv.retType); + bool hasTail = !finite(ftv.retType); + bool wrap = get(follow(ftv.retType)) && (hasTail ? retSize != 0 : retSize != 1); + + if (wrap) + state.emit("("); + + tvs.stringify(ftv.retType); + + if (wrap) + state.emit(")"); + + return result.name; +} + std::string dump(TypeId ty) { ToStringOptions opts; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 8e13ea5be..f59086834 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_FASTFLAG(LuauTypeAliasDefaults) + namespace { bool isIdentifierStartChar(char c) @@ -793,14 +795,47 @@ struct Printer for (auto o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); + } + } + else + { + writer.identifier(o.name.value); + } } for (auto o : a->genericPacks) { comma(); - writer.identifier(o.value); - writer.symbol("..."); + + if (FFlag::LuauTypeAliasDefaults) + { + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) + { + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); + } + } + else + { + writer.identifier(o.name.value); + writer.symbol("..."); + } } writer.symbol(">"); @@ -846,12 +881,20 @@ struct Printer for (const auto& o : func.generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : func.genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); @@ -979,12 +1022,20 @@ struct Printer for (const auto& o : a->generics) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); } for (const auto& o : a->genericPacks) { comma(); - writer.identifier(o.value); + + if (FFlag::LuauTypeAliasDefaults) + writer.advance(o.location.begin); + + writer.identifier(o.name.value); writer.symbol("..."); } writer.symbol(">"); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 9e61c7924..2ec020937 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -212,24 +212,24 @@ class TypeRehydrationVisitor if (hasSeen(&ftv)) return allocator->alloc(Location(), std::nullopt, AstName("")); - AstArray generics; + AstArray generics; generics.size = ftv.generics.size(); - generics.data = static_cast(allocator->allocate(sizeof(AstName) * generics.size)); + generics.data = static_cast(allocator->allocate(sizeof(AstGenericType) * generics.size)); size_t numGenerics = 0; for (auto it = ftv.generics.begin(); it != ftv.generics.end(); ++it) { if (auto gtv = get(*it)) - generics.data[numGenerics++] = AstName(gtv->name.c_str()); + generics.data[numGenerics++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } - AstArray genericPacks; + AstArray genericPacks; genericPacks.size = ftv.genericPacks.size(); - genericPacks.data = static_cast(allocator->allocate(sizeof(AstName) * genericPacks.size)); + genericPacks.data = static_cast(allocator->allocate(sizeof(AstGenericTypePack) * genericPacks.size)); size_t numGenericPacks = 0; for (auto it = ftv.genericPacks.begin(); it != ftv.genericPacks.end(); ++it) { if (auto gtv = get(*it)) - genericPacks.data[numGenericPacks++] = AstName(gtv->name.c_str()); + genericPacks.data[numGenericPacks++] = {AstName(gtv->name.c_str()), Location(), nullptr}; } AstArray argTypes; diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 1689a5c3d..bedcc0227 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -15,8 +15,8 @@ #include "Luau/TypeVar.h" #include "Luau/TimeTrace.h" -#include #include +#include LUAU_FASTFLAGVARIABLE(DebugLuauMagicTypes, false) LUAU_FASTINTVARIABLE(LuauTypeInferRecursionLimit, 500) @@ -24,25 +24,30 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) +LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionAnalysisSupport, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) +LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) +LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau @@ -279,6 +284,14 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); } + if (FFlag::LuauPerModuleUnificationCache) + { + // Clear unifier cache since it's keyed off internal types that get deallocated + // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. + unifierState.cachedUnify.clear(); + unifierState.skipCacheForType.clear(); + } + return std::move(currentModule); } @@ -1213,18 +1226,18 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias ScopePtr aliasScope = childScope(scope, typealias.location); aliasScope->level = scope->level.incr(); - for (TypeId ty : binding->typeParams) + for (auto param : binding->typeParams) { - auto generic = get(ty); + auto generic = get(param.ty); LUAU_ASSERT(generic); - aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, ty}; + aliasScope->privateTypeBindings[generic->name] = TypeFun{{}, param.ty}; } - for (TypePackId tp : binding->typePackParams) + for (auto param : binding->typePackParams) { - auto generic = get(tp); + auto generic = get(param.tp); LUAU_ASSERT(generic); - aliasScope->privateTypePackBindings[generic->name] = tp; + aliasScope->privateTypePackBindings[generic->name] = param.tp; } TypeId ty = resolveType(aliasScope, *typealias.type); @@ -1233,9 +1246,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // If the table is already named and we want to rename the type function, we have to bind new alias to a copy if (ttv->name) { + bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), + binding->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal(ttv->instantiatedTypePackParams.begin(), ttv->instantiatedTypePackParams.end(), + binding->typePackParams.begin(), binding->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // Copy can be skipped if this is an identical alias - if (ttv->name != name || ttv->instantiatedTypeParams != binding->typeParams || - ttv->instantiatedTypePackParams != binding->typePackParams) + if (ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1243,8 +1264,12 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias clone.methodDefinitionLocations = ttv->methodDefinitionLocations; clone.definitionModuleName = ttv->definitionModuleName; clone.name = name; - clone.instantiatedTypeParams = binding->typeParams; - clone.instantiatedTypePackParams = binding->typePackParams; + + for (auto param : binding->typeParams) + clone.instantiatedTypeParams.push_back(param.ty); + + for (auto param : binding->typePackParams) + clone.instantiatedTypePackParams.push_back(param.tp); ty = addType(std::move(clone)); } @@ -1252,8 +1277,14 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else { ttv->name = name; - ttv->instantiatedTypeParams = binding->typeParams; - ttv->instantiatedTypePackParams = binding->typePackParams; + + ttv->instantiatedTypeParams.clear(); + for (auto param : binding->typeParams) + ttv->instantiatedTypeParams.push_back(param.ty); + + ttv->instantiatedTypePackParams.clear(); + for (auto param : binding->typePackParams) + ttv->instantiatedTypePackParams.push_back(param.tp); } } else if (auto mtv = getMutable(follow(ty))) @@ -1367,9 +1398,21 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, global, global.generics, global.genericPacks); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + TypePackId argPack = resolveTypePack(funScope, global.params); TypePackId retPack = resolveTypePack(funScope, global.retTypes); - TypeId fnType = addType(FunctionTypeVar{funScope->level, generics, genericPacks, argPack, retPack}); + TypeId fnType = addType(FunctionTypeVar{funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack}); FunctionTypeVar* ftv = getMutable(fnType); ftv->argNames.reserve(global.paramNames.size); @@ -1394,7 +1437,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& ExprResult result; if (auto a = expr.as()) - result = checkExpr(scope, *a->expr); + result = checkExpr(scope, *a->expr, FFlag::LuauGroupExpectedType ? expectedType : std::nullopt); else if (expr.is()) result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) @@ -1438,21 +1481,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - { - if (FFlag::LuauIfElseExpressionAnalysisSupport) - { - result = checkExpr(scope, *a); - } - else - { - // Note: When the fast flag is disabled we can't skip the handling of AstExprIfElse - // because we would generate an ICE. We also can't use the default value - // of result, because it will lead to a compiler crash. - // Note: LuauIfElseExpressionBaseSupport can be used to disable parser support - // for if-else expressions which will mean this node type is never created. - result = {anyType}; - } - } + result = checkExpr(scope, *a, FFlag::LuauIfElseExpectedType2 ? expectedType : std::nullopt); else ice("Unhandled AstExpr?"); @@ -1895,7 +1924,7 @@ TypeId TypeChecker::checkExprTable( } } - TableState state = (expr.items.size == 0 || isNonstrictMode()) ? TableState::Unsealed : TableState::Sealed; + TableState state = (expr.items.size == 0 || isNonstrictMode() || FFlag::LuauUnsealedTableLiteral) ? TableState::Unsealed : TableState::Sealed; TableTypeVar table = TableTypeVar{std::move(props), indexer, scope->level, state}; table.definitionModuleName = currentModuleName; return addType(table); @@ -2549,23 +2578,34 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprEr return {errorRecoveryType(scope)}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr) +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIfElse& expr, std::optional expectedType) { ExprResult result = checkExpr(scope, *expr.condition); ScopePtr trueScope = childScope(scope, expr.trueExpr->location); reportErrors(resolve(result.predicates, trueScope, true)); - ExprResult trueType = checkExpr(trueScope, *expr.trueExpr); + ExprResult trueType = checkExpr(trueScope, *expr.trueExpr, expectedType); ScopePtr falseScope = childScope(scope, expr.falseExpr->location); // Don't report errors for this scope to avoid potentially duplicating errors reported for the first scope. resolve(result.predicates, falseScope, false); - ExprResult falseType = checkExpr(falseScope, *expr.falseExpr); + ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); - unify(falseType.type, trueType.type, expr.location); + if (FFlag::LuauIfElseBranchTypeUnion) + { + if (falseType.type == trueType.type) + return {trueType.type}; + + std::vector types = reduceUnion({trueType.type, falseType.type}); + return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; + } + else + { + unify(falseType.type, trueType.type, expr.location); - // TODO: normalize(UnionTypeVar{{trueType, falseType}}) - // For now both trueType and falseType must be the same type. - return {trueType.type}; + // TODO: normalize(UnionTypeVar{{trueType, falseType}}) + // For now both trueType and falseType must be the same type. + return {trueType.type}; + } } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) @@ -3032,7 +3072,20 @@ std::pair TypeChecker::checkFunctionSignature( defn.varargLocation = expr.vararg ? std::make_optional(expr.varargLocation) : std::nullopt; defn.originalNameLocation = originalName.value_or(Location(expr.location.begin, 0)); - TypeId funTy = addType(FunctionTypeVar(funScope->level, generics, genericPacks, argPack, retPack, std::move(defn), bool(expr.self))); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId funTy = + addType(FunctionTypeVar(funScope->level, std::move(genericTys), std::move(genericTps), argPack, retPack, std::move(defn), bool(expr.self))); FunctionTypeVar* ftv = getMutable(funTy); @@ -4848,11 +4901,38 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - if (!lit->hasParameterList && !tf->typePackParams.empty()) + bool hasDefaultTypes = false; + bool hasDefaultPacks = false; + bool parameterCountErrorReported = false; + + if (FFlag::LuauTypeAliasDefaults) { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); + hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + + if (!lit->hasParameterList) + { + if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + parameterCountErrorReported = true; + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } + } + } + else + { + if (!lit->hasParameterList && !tf->typePackParams.empty()) + { + reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } } std::vector typeParams; @@ -4892,14 +4972,89 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (typePackParams.empty() && !extraTypes.empty()) typePackParams.push_back(addTypePack(extraTypes)); + if (FFlag::LuauTypeAliasDefaults) + { + size_t typesProvided = typeParams.size(); + size_t typesRequired = tf->typeParams.size(); + + size_t packsProvided = typePackParams.size(); + size_t packsRequired = tf->typePackParams.size(); + + bool notEnoughParameters = + (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; + + // Add default type and type pack parameters if that's required and it's possible + if (notEnoughParameters && hasDefaultParameters) + { + // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.typeArguments.clear(); + applyTypeFunction.typePackArguments.clear(); + applyTypeFunction.currentModule = currentModule; + applyTypeFunction.level = scope->level; + applyTypeFunction.encounteredForwardedType = false; + + for (size_t i = 0; i < typesProvided; ++i) + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; + + if (typesProvided < typesRequired) + { + for (size_t i = typesProvided; i < typesRequired; ++i) + { + TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); + + if (!defaultTy) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryType(scope); + } + + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; + typeParams.push_back(*maybeInstantiated); + } + } + + for (size_t i = 0; i < packsProvided; ++i) + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; + + if (packsProvided < packsRequired) + { + for (size_t i = packsProvided; i < packsRequired; ++i) + { + TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); + + if (!defaultTp) + break; + + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); + + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryTypePack(scope); + } + + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; + typePackParams.push_back(*maybeInstantiated); + } + } + } + } + // If we didn't combine regular types into a type pack and we're still one type pack short, provide an empty type pack if (extraTypes.empty() && typePackParams.size() + 1 == tf->typePackParams.size()) typePackParams.push_back(addTypePack({})); if (typeParams.size() != tf->typeParams.size() || typePackParams.size() != tf->typePackParams.size()) { - reportError( - TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); + if (!parameterCountErrorReported) + reportError( + TypeError{annotation.location, IncorrectGenericParameterCount{lit->name.value, *tf, typeParams.size(), typePackParams.size()}}); if (FFlag::LuauErrorRecoveryType) { @@ -4913,11 +5068,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation return errorRecoveryType(scope); } - if (FFlag::LuauRecursiveTypeParameterRestriction && typeParams == tf->typeParams && typePackParams == tf->typePackParams) + if (FFlag::LuauRecursiveTypeParameterRestriction) { + bool sameTys = std::equal(typeParams.begin(), typeParams.end(), tf->typeParams.begin(), tf->typeParams.end(), [](auto&& itp, auto&& tp) { + return itp == tp.ty; + }); + bool sameTps = std::equal( + typePackParams.begin(), typePackParams.end(), tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& itpp, auto&& tpp) { + return itpp == tpp.tp; + }); + // If the generic parameters and the type arguments are the same, we are about to // perform an identity substitution, which we can just short-circuit. - return tf->type; + if (sameTys && sameTps) + return tf->type; } return instantiateTypeFun(scope, *tf, typeParams, typePackParams, annotation.location); @@ -4948,7 +5112,19 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation TypePackId argTypes = resolveTypePack(funcScope, func->argTypes); TypePackId retTypes = resolveTypePack(funcScope, func->returnTypes); - TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(generics), std::move(genericPacks), argTypes, retTypes}); + std::vector genericTys; + genericTys.reserve(generics.size()); + std::transform(generics.begin(), generics.end(), std::back_inserter(genericTys), [](auto&& el) { + return el.ty; + }); + + std::vector genericTps; + genericTps.reserve(genericPacks.size()); + std::transform(genericPacks.begin(), genericPacks.end(), std::back_inserter(genericTps), [](auto&& el) { + return el.tp; + }); + + TypeId fnType = addType(FunctionTypeVar{funcScope->level, std::move(genericTys), std::move(genericTps), argTypes, retTypes}); FunctionTypeVar* ftv = getMutable(fnType); @@ -5137,11 +5313,11 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, applyTypeFunction.typeArguments.clear(); for (size_t i = 0; i < tf.typeParams.size(); ++i) - applyTypeFunction.typeArguments[tf.typeParams[i]] = typeParams[i]; + applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; applyTypeFunction.typePackArguments.clear(); for (size_t i = 0; i < tf.typePackParams.size(); ++i) - applyTypeFunction.typePackArguments[tf.typePackParams[i]] = typePackParams[i]; + applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; @@ -5213,17 +5389,23 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, return instantiated; } -std::pair, std::vector> TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, - const AstNode& node, const AstArray& genericNames, const AstArray& genericPackNames) +GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, + const AstArray& genericNames, const AstArray& genericPackNames) { LUAU_ASSERT(scope->parent); const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; - std::vector generics; - for (const AstName& generic : genericNames) + std::vector generics; + + for (const AstGenericType& generic : genericNames) { - Name n = generic.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && generic.defaultValue) + defaultValue = resolveType(scope, *generic.defaultValue); + + Name n = generic.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5246,14 +5428,20 @@ std::pair, std::vector> TypeChecker::createGener g = addType(Unifiable::Generic{level, n}); } - generics.push_back(g); + generics.push_back({g, defaultValue}); scope->privateTypeBindings[n] = TypeFun{{}, g}; } - std::vector genericPacks; - for (const AstName& genericPack : genericPackNames) + std::vector genericPacks; + + for (const AstGenericTypePack& genericPack : genericPackNames) { - Name n = genericPack.value; + std::optional defaultValue; + + if (FFlag::LuauTypeAliasDefaults && genericPack.defaultValue) + defaultValue = resolveTypePack(scope, *genericPack.defaultValue); + + Name n = genericPack.name.value; // These generics are the only thing that will ever be added to scope, so we can be certain that // a collision can only occur when two generic typevars have the same name. @@ -5276,7 +5464,7 @@ std::pair, std::vector> TypeChecker::createGener g = addTypePack(TypePackVar{Unifiable::Generic{level, n}}); } - genericPacks.push_back(g); + genericPacks.push_back({g, defaultValue}); scope->privateTypePackBindings[n] = g; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 4cab79c8a..ac2b25410 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -19,6 +19,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) +LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(DebugLuauFreezeArena) @@ -453,6 +454,9 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) { + if (FFlag::LuauMetatableAreEqualRecursion && areSeen(seen, &lhs, &rhs)) + return true; + return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 393a84a70..6873c657c 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -22,6 +22,7 @@ LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); +LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) @@ -1170,9 +1171,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); @@ -1370,9 +1377,15 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { + if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } + const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (lFreeTail && rFreeTail) + if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) tryUnify_(*subTpv->tail, *superTpv->tail); else if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 5b4bfa033..573850a5a 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -334,6 +334,20 @@ struct AstTypeList using AstArgumentName = std::pair; // TODO: remove and replace when we get a common struct for this pair instead of AstName +struct AstGenericType +{ + AstName name; + Location location; + AstType* defaultValue = nullptr; +}; + +struct AstGenericTypePack +{ + AstName name; + Location location; + AstTypePack* defaultValue = nullptr; +}; + extern int gAstRttiIndex; template @@ -569,15 +583,15 @@ class AstExprFunction : public AstExpr public: LUAU_RTTI(AstExprFunction) - AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, + AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation = {}, AstTypePack* varargAnnotation = nullptr, bool hasEnd = false, std::optional argLocation = std::nullopt); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstLocal* self; AstArray args; bool hasReturnAnnotation; @@ -942,14 +956,14 @@ class AstStatTypeAlias : public AstStat public: LUAU_RTTI(AstStatTypeAlias) - AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - AstType* type, bool exported); + AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstType* type; bool exported; }; @@ -972,14 +986,15 @@ class AstStatDeclareFunction : public AstStat public: LUAU_RTTI(AstStatDeclareFunction) - AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, const AstArray& genericPacks, - const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes); + AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes); void visit(AstVisitor* visitor) override; AstName name; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList params; AstArray paramNames; AstTypeList retTypes; @@ -1077,13 +1092,13 @@ class AstTypeFunction : public AstType public: LUAU_RTTI(AstTypeFunction) - AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, - const AstArray>& argNames, const AstTypeList& returnTypes); + AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes); void visit(AstVisitor* visitor) override; - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; AstTypeList argTypes; AstArray> argNames; AstTypeList returnTypes; diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 87ebc48b5..40ecdcdd5 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -219,7 +219,7 @@ class Parser AstTableIndexer* parseTableIndexerAnnotation(); AstTypeOrPack parseFunctionTypeAnnotation(bool allowPack); - AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, + AstType* parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation); AstType* parseTableTypeAnnotation(); @@ -281,7 +281,7 @@ class Parser Name parseIndexName(const char* context, const Position& previous); // `<' namelist `>' - std::pair, AstArray> parseGenericTypeList(); + std::pair, AstArray> parseGenericTypeList(bool withDefaultValues); // `<' typeAnnotation[, ...] `>' AstArray parseTypeParams(); @@ -418,6 +418,8 @@ class Parser std::vector scratchDeclaredClassProps; std::vector scratchItem; std::vector scratchArgName; + std::vector scratchGenericTypes; + std::vector scratchGenericTypePacks; std::vector> scratchOptArgName; std::string scratchData; }; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index e709894d9..9b5bc0c71 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -158,9 +158,10 @@ void AstExprIndexExpr::visit(AstVisitor* visitor) } } -AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, AstLocal* self, - const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, const AstName& debugname, - std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, std::optional argLocation) +AstExprFunction::AstExprFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, + AstLocal* self, const AstArray& args, std::optional vararg, AstStatBlock* body, size_t functionDepth, + const AstName& debugname, std::optional returnAnnotation, AstTypePack* varargAnnotation, bool hasEnd, + std::optional argLocation) : AstExpr(ClassIndex(), location) , generics(generics) , genericPacks(genericPacks) @@ -641,8 +642,8 @@ void AstStatLocalFunction::visit(AstVisitor* visitor) func->visit(visitor); } -AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, AstType* type, bool exported) +AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, AstType* type, bool exported) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -655,7 +656,21 @@ AstStatTypeAlias::AstStatTypeAlias(const Location& location, const AstName& name void AstStatTypeAlias::visit(AstVisitor* visitor) { if (visitor->visit(this)) + { + for (const AstGenericType& el : generics) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + + for (const AstGenericTypePack& el : genericPacks) + { + if (el.defaultValue) + el.defaultValue->visit(visitor); + } + type->visit(visitor); + } } AstStatDeclareGlobal::AstStatDeclareGlobal(const Location& location, const AstName& name, AstType* type) @@ -671,8 +686,9 @@ void AstStatDeclareGlobal::visit(AstVisitor* visitor) type->visit(visitor); } -AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, - const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, const AstTypeList& retTypes) +AstStatDeclareFunction::AstStatDeclareFunction(const Location& location, const AstName& name, const AstArray& generics, + const AstArray& genericPacks, const AstTypeList& params, const AstArray& paramNames, + const AstTypeList& retTypes) : AstStat(ClassIndex(), location) , name(name) , generics(generics) @@ -778,7 +794,7 @@ void AstTypeTable::visit(AstVisitor* visitor) } } -AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, +AstTypeFunction::AstTypeFunction(const Location& location, const AstArray& generics, const AstArray& genericPacks, const AstTypeList& argTypes, const AstArray>& argNames, const AstTypeList& returnTypes) : AstType(ClassIndex(), location) , generics(generics) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 72f616497..77787cb1c 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,10 +10,10 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpressionBaseSupport, false) -LUAU_FASTFLAGVARIABLE(LuauIfStatementRecursionGuard, false) LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) +LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) namespace Luau { @@ -394,23 +394,13 @@ AstStat* Parser::parseIf() if (lexer.current().type == Lexeme::ReservedElseif) { - if (FFlag::LuauIfStatementRecursionGuard) - { - unsigned int recursionCounterOld = recursionCounter; - incrementRecursionCounter("elseif"); - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - recursionCounter = recursionCounterOld; - } - else - { - elseLocation = lexer.current().location; - elsebody = parseIf(); - end = elsebody->location; - hasEnd = elsebody->as()->hasEnd; - } + unsigned int recursionCounterOld = recursionCounter; + incrementRecursionCounter("elseif"); + elseLocation = lexer.current().location; + elsebody = parseIf(); + end = elsebody->location; + hasEnd = elsebody->as()->hasEnd; + recursionCounter = recursionCounterOld; } else { @@ -772,7 +762,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ FFlag::LuauParseTypeAliasDefaults); expectAndConsume('=', "type alias"); @@ -788,8 +778,8 @@ AstDeclaredClassProp Parser::parseDeclaredClassMethod() Name fnName = parseName("function name"); // TODO: generic method declarations CLI-39909 - AstArray generics; - AstArray genericPacks; + AstArray generics; + AstArray genericPacks; generics.size = 0; generics.data = nullptr; genericPacks.size = 0; @@ -849,7 +839,7 @@ AstStat* Parser::parseDeclaration(const Location& start) nextLexeme(); Name globalName = parseName("global function name"); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); @@ -991,7 +981,7 @@ std::pair Parser::parseFunctionBody( { Location start = matchFunction.location; - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme matchParen = lexer.current(); expectAndConsume('(', "function"); @@ -1228,8 +1218,8 @@ std::pair Parser::parseReturnTypeAnnotation() return {location, AstTypeList{copy(result), varargAnnotation}}; } - AstArray generics{nullptr, 0}; - AstArray genericPacks{nullptr, 0}; + AstArray generics{nullptr, 0}; + AstArray genericPacks{nullptr, 0}; AstArray types = copy(result); AstArray> names = copy(resultNames); @@ -1363,7 +1353,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) Lexeme begin = lexer.current(); - auto [generics, genericPacks] = parseGenericTypeList(); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ false); Lexeme parameterStart = lexer.current(); @@ -1401,7 +1391,7 @@ AstTypeOrPack Parser::parseFunctionTypeAnnotation(bool allowPack) return {parseFunctionTypeAnnotationTail(begin, generics, genericPacks, paramTypes, paramNames, varargAnnotation), {}}; } -AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, +AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray generics, AstArray genericPacks, AstArray& params, AstArray>& paramNames, AstTypePack* varargAnnotation) { @@ -1448,7 +1438,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location if (c == '|') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isUnion = true; } else if (c == '?') @@ -1461,7 +1451,7 @@ AstType* Parser::parseTypeAnnotation(TempVector& parts, const Location else if (c == '&') { nextLexeme(); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); isIntersection = true; } else @@ -1498,7 +1488,7 @@ AstTypeOrPack Parser::parseTypeOrPackAnnotation() TempVector parts(scratchAnnotation); - auto [type, typePack] = parseSimpleTypeAnnotation(true); + auto [type, typePack] = parseSimpleTypeAnnotation(/* allowPack= */ true); if (typePack) { @@ -1521,7 +1511,7 @@ AstType* Parser::parseTypeAnnotation() Location begin = lexer.current().location; TempVector parts(scratchAnnotation); - parts.push_back(parseSimpleTypeAnnotation(false).type); + parts.push_back(parseSimpleTypeAnnotation(/* allowPack= */ false).type); recursionCounter = oldRecursionCount; @@ -2121,7 +2111,7 @@ AstExpr* Parser::parseSimpleExpr() { return parseTableConstructor(); } - else if (FFlag::LuauIfElseExpressionBaseSupport && lexer.current().type == Lexeme::ReservedIf) + else if (lexer.current().type == Lexeme::ReservedIf) { return parseIfElseExpr(); } @@ -2341,10 +2331,10 @@ Parser::Name Parser::parseIndexName(const char* context, const Position& previou return Name(nameError, location); } -std::pair, AstArray> Parser::parseGenericTypeList() +std::pair, AstArray> Parser::parseGenericTypeList(bool withDefaultValues) { - TempVector names{scratchName}; - TempVector namePacks{scratchPackName}; + TempVector names{scratchGenericTypes}; + TempVector namePacks{scratchGenericTypePacks}; if (lexer.current().type == '<') { @@ -2352,21 +2342,73 @@ std::pair, AstArray> Parser::parseGenericTypeList() nextLexeme(); bool seenPack = false; + bool seenDefault = false; + while (true) { + Location nameLocation = lexer.current().location; AstName name = parseName().name; - if (lexer.current().type == Lexeme::Dot3) + if (lexer.current().type == Lexeme::Dot3 || (FFlag::LuauParseRecoverTypePackEllipsis && seenPack)) { seenPack = true; - nextLexeme(); - namePacks.push_back(name); + + if (FFlag::LuauParseRecoverTypePackEllipsis && lexer.current().type != Lexeme::Dot3) + report(lexer.current().location, "Generic types come before generic type packs"); + else + nextLexeme(); + + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + Lexeme packBegin = lexer.current(); + + if (shouldParseTypePackAnnotation(lexer)) + { + auto typePack = parseTypePackAnnotation(); + + namePacks.push_back({name, nameLocation, typePack}); + } + else if (lexer.current().type == '(') + { + auto [type, typePack] = parseTypeOrPackAnnotation(); + + if (type) + report(Location(packBegin.location.begin, lexer.previousLocation().end), "Expected type pack after '=', got type"); + + namePacks.push_back({name, nameLocation, typePack}); + } + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type pack after type pack name"); + + namePacks.push_back({name, nameLocation, nullptr}); + } } else { - if (seenPack) + if (!FFlag::LuauParseRecoverTypePackEllipsis && seenPack) report(lexer.current().location, "Generic types come before generic type packs"); - names.push_back(name); + if (withDefaultValues && lexer.current().type == '=') + { + seenDefault = true; + nextLexeme(); + + AstType* defaultType = parseTypeAnnotation(); + + names.push_back({name, nameLocation, defaultType}); + } + else + { + if (seenDefault) + report(lexer.current().location, "Expected default type after type name"); + + names.push_back({name, nameLocation, nullptr}); + } } if (lexer.current().type == ',') @@ -2378,8 +2420,8 @@ std::pair, AstArray> Parser::parseGenericTypeList() expectMatchAndConsume('>', begin); } - AstArray generics = copy(names); - AstArray genericPacks = copy(namePacks); + AstArray generics = copy(names); + AstArray genericPacks = copy(namePacks); return {generics, genericPacks}; } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index 54a9a26fa..e0dc3e0fe 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -8,6 +8,8 @@ #include "FileUtils.h" +LUAU_FASTFLAG(DebugLuauTimeTracing) + enum class ReportFormat { Default, @@ -105,6 +107,7 @@ static void displayHelp(const char* argv0) printf("Available options:\n"); printf(" --formatter=plain: report analysis errors in Luacheck-compatible format\n"); printf(" --formatter=gnu: report analysis errors in GNU-compatible format\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -213,7 +216,17 @@ int main(int argc, char** argv) format = ReportFormat::Gnu; else if (strcmp(argv[i], "--annotate") == 0) annotate = true; + else if (strcmp(argv[i], "--timetrace") == 0) + FFlag::DebugLuauTimeTracing.value = true; + } + +#if !defined(LUAU_ENABLE_TIME_TRACE) + if (FFlag::DebugLuauTimeTracing) + { + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; } +#endif Luau::FrontendOptions frontendOptions; frontendOptions.retainFullTypeGraphs = annotate; @@ -240,7 +253,10 @@ int main(int argc, char** argv) fprintf(stderr, "%s: %s\n", pair.first.c_str(), pair.second.c_str()); } - return (format == ReportFormat::Luacheck) ? 0 : failed; + if (format == ReportFormat::Luacheck) + return 0; + else + return failed ? 1 : 0; } diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 26d4333a9..36747f487 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -19,6 +19,16 @@ #include #endif +LUAU_FASTFLAG(DebugLuauTimeTracing) + +enum class CliMode +{ + Unknown, + Repl, + Compile, + RunSourceFiles +}; + enum class CompileFormat { Text, @@ -485,8 +495,10 @@ static void displayHelp(const char* argv0) printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); + printf(" -h, --help: Display this usage message.\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); + printf(" --timetrace: record compiler time tracing information into trace.json\n"); } static int assertionHandler(const char* expr, const char* file, int line, const char* function) @@ -503,71 +515,112 @@ int main(int argc, char** argv) if (strncmp(flag->name, "Luau", 4) == 0) flag->value = true; - if (argc == 1) + CliMode mode = CliMode::Unknown; + CompileFormat compileFormat{}; + int profile = 0; + bool coverage = false; + + // Set the mode if the user has explicitly specified one. + int argStart = 1; + if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) { - runRepl(); - return 0; + argStart++; + mode = CliMode::Compile; + if (strcmp(argv[1], "--compile") == 0) + { + compileFormat = CompileFormat::Text; + } + else if (strcmp(argv[1], "--compile=binary") == 0) + { + compileFormat = CompileFormat::Binary; + } + else if (strcmp(argv[1], "--compile=text") == 0) + { + compileFormat = CompileFormat::Text; + } + else + { + fprintf(stdout, "Error: Unrecognized value for '--compile' specified.\n"); + return -1; + } } - if (argc >= 2 && strcmp(argv[1], "--help") == 0) + for (int i = argStart; i < argc; i++) { - displayHelp(argv[0]); - return 0; - } + if (strcmp(argv[i], "-h") == 0 || strcmp(argv[i], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (strcmp(argv[i], "--profile") == 0) + { + profile = 10000; // default to 10 KHz + } + else if (strncmp(argv[i], "--profile=", 10) == 0) + { + profile = atoi(argv[i] + 10); + } + else if (strcmp(argv[i], "--coverage") == 0) + { + coverage = true; + } + else if (strcmp(argv[i], "--timetrace") == 0) + { + FFlag::DebugLuauTimeTracing.value = true; +#if !defined(LUAU_ENABLE_TIME_TRACE) + printf("To run with --timetrace, Luau has to be built with LUAU_ENABLE_TIME_TRACE enabled\n"); + return 1; +#endif + } + else if (argv[i][0] == '-') + { + fprintf(stdout, "Error: Unrecognized option '%s'.\n\n", argv[i]); + displayHelp(argv[0]); + return 1; + } + } - if (argc >= 2 && strncmp(argv[1], "--compile", strlen("--compile")) == 0) + const std::vector files = getSourceFiles(argc, argv); + if (mode == CliMode::Unknown) { - CompileFormat format = CompileFormat::Text; - - if (strcmp(argv[1], "--compile=binary") == 0) - format = CompileFormat::Binary; + mode = files.empty() ? CliMode::Repl : CliMode::RunSourceFiles; + } + switch (mode) + { + case CliMode::Compile: + { #ifdef _WIN32 - if (format == CompileFormat::Binary) + if (compileFormat == CompileFormat::Binary) _setmode(_fileno(stdout), _O_BINARY); #endif - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) - failed += !compileFile(path.c_str(), format); + failed += !compileFile(path.c_str(), compileFormat); - return failed; + return failed ? 1 : 0; } - + case CliMode::Repl: + { + runRepl(); + return 0; + } + case CliMode::RunSourceFiles: { std::unique_ptr globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); setupState(L); - int profile = 0; - bool coverage = false; - - for (int i = 1; i < argc; ++i) - { - if (argv[i][0] != '-') - continue; - - if (strcmp(argv[i], "--profile") == 0) - profile = 10000; // default to 10 KHz - else if (strncmp(argv[i], "--profile=", 10) == 0) - profile = atoi(argv[i] + 10); - else if (strcmp(argv[i], "--coverage") == 0) - coverage = true; - } - if (profile) profilerStart(L, profile); if (coverage) coverageInit(L); - std::vector files = getSourceFiles(argc, argv); - int failed = 0; for (const std::string& path : files) @@ -582,7 +635,11 @@ int main(int argc, char** argv) if (coverage) coverageDump("coverage.out"); - return failed; + return failed ? 1 : 0; + } + case CliMode::Unknown: + default: + LUAU_ASSERT(!"Unhandled cli mode."); } } diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp new file mode 100644 index 000000000..e344eb917 --- /dev/null +++ b/Compiler/src/Builtins.cpp @@ -0,0 +1,197 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Builtins.h" + +#include "Luau/Bytecode.h" +#include "Luau/Compiler.h" + +namespace Luau +{ +namespace Compile +{ + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables) +{ + if (AstExprLocal* expr = node->as()) + { + const Variable* v = variables.find(expr->local); + + return v && !v->written && v->init ? getBuiltin(v->init, globals, variables) : Builtin(); + } + else if (AstExprIndexName* expr = node->as()) + { + if (AstExprGlobal* object = expr->expr->as()) + { + return getGlobalState(globals, object->name) == Global::Default ? Builtin{object->name, expr->index} : Builtin(); + } + else + { + return Builtin(); + } + } + else if (AstExprGlobal* expr = node->as()) + { + return getGlobalState(globals, expr->name) == Global::Default ? Builtin{AstName(), expr->name} : Builtin(); + } + else + { + return Builtin(); + } +} + +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) +{ + if (builtin.empty()) + return -1; + + if (builtin.isGlobal("assert")) + return LBF_ASSERT; + + if (builtin.isGlobal("type")) + return LBF_TYPE; + + if (builtin.isGlobal("typeof")) + return LBF_TYPEOF; + + if (builtin.isGlobal("rawset")) + return LBF_RAWSET; + if (builtin.isGlobal("rawget")) + return LBF_RAWGET; + if (builtin.isGlobal("rawequal")) + return LBF_RAWEQUAL; + + if (builtin.isGlobal("unpack")) + return LBF_TABLE_UNPACK; + + if (builtin.object == "math") + { + if (builtin.method == "abs") + return LBF_MATH_ABS; + if (builtin.method == "acos") + return LBF_MATH_ACOS; + if (builtin.method == "asin") + return LBF_MATH_ASIN; + if (builtin.method == "atan2") + return LBF_MATH_ATAN2; + if (builtin.method == "atan") + return LBF_MATH_ATAN; + if (builtin.method == "ceil") + return LBF_MATH_CEIL; + if (builtin.method == "cosh") + return LBF_MATH_COSH; + if (builtin.method == "cos") + return LBF_MATH_COS; + if (builtin.method == "deg") + return LBF_MATH_DEG; + if (builtin.method == "exp") + return LBF_MATH_EXP; + if (builtin.method == "floor") + return LBF_MATH_FLOOR; + if (builtin.method == "fmod") + return LBF_MATH_FMOD; + if (builtin.method == "frexp") + return LBF_MATH_FREXP; + if (builtin.method == "ldexp") + return LBF_MATH_LDEXP; + if (builtin.method == "log10") + return LBF_MATH_LOG10; + if (builtin.method == "log") + return LBF_MATH_LOG; + if (builtin.method == "max") + return LBF_MATH_MAX; + if (builtin.method == "min") + return LBF_MATH_MIN; + if (builtin.method == "modf") + return LBF_MATH_MODF; + if (builtin.method == "pow") + return LBF_MATH_POW; + if (builtin.method == "rad") + return LBF_MATH_RAD; + if (builtin.method == "sinh") + return LBF_MATH_SINH; + if (builtin.method == "sin") + return LBF_MATH_SIN; + if (builtin.method == "sqrt") + return LBF_MATH_SQRT; + if (builtin.method == "tanh") + return LBF_MATH_TANH; + if (builtin.method == "tan") + return LBF_MATH_TAN; + if (builtin.method == "clamp") + return LBF_MATH_CLAMP; + if (builtin.method == "sign") + return LBF_MATH_SIGN; + if (builtin.method == "round") + return LBF_MATH_ROUND; + } + + if (builtin.object == "bit32") + { + if (builtin.method == "arshift") + return LBF_BIT32_ARSHIFT; + if (builtin.method == "band") + return LBF_BIT32_BAND; + if (builtin.method == "bnot") + return LBF_BIT32_BNOT; + if (builtin.method == "bor") + return LBF_BIT32_BOR; + if (builtin.method == "bxor") + return LBF_BIT32_BXOR; + if (builtin.method == "btest") + return LBF_BIT32_BTEST; + if (builtin.method == "extract") + return LBF_BIT32_EXTRACT; + if (builtin.method == "lrotate") + return LBF_BIT32_LROTATE; + if (builtin.method == "lshift") + return LBF_BIT32_LSHIFT; + if (builtin.method == "replace") + return LBF_BIT32_REPLACE; + if (builtin.method == "rrotate") + return LBF_BIT32_RROTATE; + if (builtin.method == "rshift") + return LBF_BIT32_RSHIFT; + if (builtin.method == "countlz") + return LBF_BIT32_COUNTLZ; + if (builtin.method == "countrz") + return LBF_BIT32_COUNTRZ; + } + + if (builtin.object == "string") + { + if (builtin.method == "byte") + return LBF_STRING_BYTE; + if (builtin.method == "char") + return LBF_STRING_CHAR; + if (builtin.method == "len") + return LBF_STRING_LEN; + if (builtin.method == "sub") + return LBF_STRING_SUB; + } + + if (builtin.object == "table") + { + if (builtin.method == "insert") + return LBF_TABLE_INSERT; + if (builtin.method == "unpack") + return LBF_TABLE_UNPACK; + } + + if (options.vectorCtor) + { + if (options.vectorLib) + { + if (builtin.isMethod(options.vectorLib, options.vectorCtor)) + return LBF_VECTOR; + } + else + { + if (builtin.isGlobal(options.vectorCtor)) + return LBF_VECTOR; + } + } + + return -1; +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/Builtins.h b/Compiler/src/Builtins.h new file mode 100644 index 000000000..60df53a18 --- /dev/null +++ b/Compiler/src/Builtins.h @@ -0,0 +1,41 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +struct CompileOptions; +} + +namespace Luau +{ +namespace Compile +{ + +struct Builtin +{ + AstName object; + AstName method; + + bool empty() const + { + return object == AstName() && method == AstName(); + } + + bool isGlobal(const char* name) const + { + return object == AstName() && method == name; + } + + bool isMethod(const char* table, const char* name) const + { + return object == table && method == name; + } +}; + +Builtin getBuiltin(AstExpr* node, const DenseHashMap& globals, const DenseHashMap& variables); +int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index 2d31c409c..e6d024546 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -714,7 +714,7 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const // third pass: write resulting data int logspan = log2(span); - writeByte(ss, logspan); + writeByte(ss, uint8_t(logspan)); uint8_t lastOffset = 0; @@ -723,8 +723,8 @@ void BytecodeBuilder::writeLineInfo(std::string& ss) const int delta = lines[i] - baseline[i >> logspan]; LUAU_ASSERT(delta >= 0 && delta <= 255); - writeByte(ss, delta - lastOffset); - lastOffset = delta; + writeByte(ss, uint8_t(delta) - lastOffset); + lastOffset = uint8_t(delta); } int lastLine = 0; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 6ae490273..9758c4a9a 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -6,15 +6,20 @@ #include "Luau/Common.h" #include "Luau/TimeTrace.h" +#include "Builtins.h" +#include "ConstantFolding.h" +#include "TableShape.h" +#include "ValueTracking.h" + #include #include #include -LUAU_FASTFLAG(LuauIfElseExpressionBaseSupport) - namespace Luau { +using namespace Luau::Compile; + static const uint32_t kMaxRegisterCount = 255; static const uint32_t kMaxUpvalueCount = 200; static const uint32_t kMaxLocalCount = 200; @@ -62,7 +67,6 @@ static BytecodeBuilder::StringRef sref(AstArray data) struct Compiler { - struct Constant; struct RegScope; Compiler(BytecodeBuilder& bytecode, const CompileOptions& options) @@ -71,8 +75,9 @@ struct Compiler , functions(nullptr) , locals(nullptr) , globals(AstName()) + , variables(nullptr) , constants(nullptr) - , predictedTableSize(nullptr) + , tableShapes(nullptr) { } @@ -96,8 +101,10 @@ struct Compiler local->location, "Out of upvalue registers when trying to allocate %s: exceeded limit %d", local->name.value, kMaxUpvalueCount); // mark local as captured so that closeLocals emits LOP_CLOSEUPVALS accordingly - Local& l = locals[local]; - l.captured = true; + Variable* v = variables.find(local); + + if (v && v->written) + locals[local].captured = true; upvals.push_back(local); @@ -273,8 +280,8 @@ struct Compiler if (options.optimizationLevel >= 1) { - Builtin builtin = getBuiltin(expr->func); - bfid = getBuiltinFunctionId(builtin); + Builtin builtin = getBuiltin(expr->func, globals, variables); + bfid = getBuiltinFunctionId(builtin, options); } if (expr->self) @@ -364,12 +371,12 @@ struct Compiler else { args[i] = uint8_t(regs + 1 + i); - compileExprTempTop(expr->args.data[i], args[i]); + compileExprTempTop(expr->args.data[i], uint8_t(args[i])); } } fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(opc, uint8_t(bfid), args[0], 0); + bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); if (opc != LOP_FASTCALL1) bytecode.emitAux(args[1]); @@ -385,7 +392,7 @@ struct Compiler } if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), args[i], 0); + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); } } else @@ -424,8 +431,10 @@ struct Compiler for (AstLocal* uv : f->upvals) { - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); + Variable* ul = variables.find(uv); + + if (!ul) + return false; if (ul->written) return false; @@ -437,10 +446,11 @@ struct Compiler // this will only deoptimize (outside of fenv changes) if top level code is executed twice with different results. if (uv->functionDepth != 0 || uv->loopDepth != 0) { - if (!ul->func) + AstExprFunction* uf = ul->init ? ul->init->as() : nullptr; + if (!uf) return false; - if (ul->func != func && !shouldShareClosure(ul->func)) + if (uf != func && !shouldShareClosure(uf)) return false; } } @@ -471,7 +481,7 @@ struct Compiler if (cid >= 0 && cid < 32768) { - bytecode.emitAD(LOP_DUPCLOSURE, target, cid); + bytecode.emitAD(LOP_DUPCLOSURE, target, int16_t(cid)); shared = true; } } @@ -483,17 +493,15 @@ struct Compiler { LUAU_ASSERT(uv->functionDepth < expr->functionDepth); - Local* ul = locals.find(uv); - LUAU_ASSERT(ul); - - bool immutable = !ul->written; + Variable* ul = variables.find(uv); + bool immutable = !ul || !ul->written; if (uv->functionDepth == expr->functionDepth - 1) { // get local variable uint8_t reg = getLocal(uv); - bytecode.emitABC(LOP_CAPTURE, immutable ? LCT_VAL : LCT_REF, reg, 0); + bytecode.emitABC(LOP_CAPTURE, uint8_t(immutable ? LCT_VAL : LCT_REF), reg, 0); } else { @@ -635,7 +643,7 @@ struct Compiler if (expr->op == AstExprBinary::CompareGt || expr->op == AstExprBinary::CompareGe) { - bytecode.emitAD(opc, rr, 0); + bytecode.emitAD(opc, uint8_t(rr), 0); bytecode.emitAux(rl); } else @@ -687,7 +695,7 @@ struct Compiler break; case Constant::Type_String: - cid = bytecode.addConstantString(sref(c->valueString)); + cid = bytecode.addConstantString(sref(c->getString())); break; default: @@ -1066,10 +1074,10 @@ struct Compiler // Optimization: if the table is empty, we can compute it directly into the target if (expr->items.size == 0) { - auto [hashSize, arraySize] = predictedTableSize[expr]; + TableShape shape = tableShapes[expr]; - bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(hashSize), 0); - bytecode.emitAux(arraySize); + bytecode.emitABC(LOP_NEWTABLE, target, encodeHashSize(shape.hashSize), 0); + bytecode.emitAux(shape.arraySize); return; } @@ -1144,7 +1152,7 @@ struct Compiler } else { - bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitABC(LOP_NEWTABLE, reg, uint8_t(encodedHashSize), 0); bytecode.emitAux(0); } } @@ -1157,7 +1165,7 @@ struct Compiler bool trailingVarargs = last && last->kind == AstExprTable::Item::List && last->value->is(); LUAU_ASSERT(!trailingVarargs || arraySize > 0); - bytecode.emitABC(LOP_NEWTABLE, reg, encodedHashSize, 0); + bytecode.emitABC(LOP_NEWTABLE, reg, uint8_t(encodedHashSize), 0); bytecode.emitAux(arraySize - trailingVarargs + indexSize); } @@ -1252,16 +1260,12 @@ struct Compiler bool canImport(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || !global->written); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) != Global::Written; } bool canImportChain(AstExprGlobal* expr) { - const Global* global = globals.find(expr->name); - - return options.optimizationLevel >= 1 && (!global || (!global->written && !global->writable)); + return options.optimizationLevel >= 1 && getGlobalState(globals, expr->name) == Global::Default; } void compileExprIndexName(AstExprIndexName* expr, uint8_t target) @@ -1341,7 +1345,7 @@ struct Compiler { uint8_t rt = compileExprAuto(expr->expr, rs); - BytecodeBuilder::StringRef iname = sref(cv->valueString); + BytecodeBuilder::StringRef iname = sref(cv->getString()); int32_t cid = bytecode.addConstantString(iname); if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); @@ -1427,7 +1431,7 @@ struct Compiler case Constant::Type_String: { - int32_t cid = bytecode.addConstantString(sref(cv->valueString)); + int32_t cid = bytecode.addConstantString(sref(cv->getString())); if (cid < 0) CompileError::raise(node->location, "Exceeded constant limit; simplify the code to compile"); @@ -1546,7 +1550,7 @@ struct Compiler { compileExpr(expr->expr, target, targetTemp); } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) + else if (AstExprIfElse* expr = node->as()) { compileExprIfElse(expr, target, targetTemp); } @@ -1711,7 +1715,7 @@ struct Compiler { LValue result = {LValue::Kind_IndexName}; result.reg = compileExprAuto(expr->expr, rs); - result.name = sref(cv->valueString); + result.name = sref(cv->getString()); result.location = node->location; return result; @@ -1796,9 +1800,8 @@ struct Compiler return false; Local* l = locals.find(le->local); - LUAU_ASSERT(l); - return l->allocated; + return l && l->allocated; } bool isStatBreak(AstStat* node) @@ -2040,9 +2043,9 @@ struct Compiler for (AstLocal* local : stat->vars) { - Local* l = locals.find(local); + Variable* v = variables.find(local); - if (!l || l->constant.type == Constant::Type_Unknown) + if (!v || !v->constant) return false; } @@ -2082,9 +2085,7 @@ struct Compiler // through) uint8_t varreg = regs + 2; - Local* il = locals.find(stat->var); - - if (il && il->written) + if (Variable* il = variables.find(stat->var); il && il->written) varreg = allocReg(stat, 1); compileExprTemp(stat->from, uint8_t(regs + 2)); @@ -2164,7 +2165,7 @@ struct Compiler { if (stat->values.size == 1 && stat->values.data[0]->is()) { - Builtin builtin = getBuiltin(stat->values.data[0]->as()->func); + Builtin builtin = getBuiltin(stat->values.data[0]->as()->func, globals, variables); if (builtin.isGlobal("ipairs")) // for .. in ipairs(t) { @@ -2179,7 +2180,7 @@ struct Compiler } else if (stat->values.size == 2) { - Builtin builtin = getBuiltin(stat->values.data[0]); + Builtin builtin = getBuiltin(stat->values.data[0], globals, variables); if (builtin.isGlobal("next")) // for .. in next,t { @@ -2594,7 +2595,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) return true; } @@ -2613,7 +2614,7 @@ struct Compiler Local* l = locals.find(localStack[i]); LUAU_ASSERT(l); - if (l->captured && l->written) + if (l->captured) { captured = true; captureReg = std::min(captureReg, l->reg); @@ -2728,519 +2729,6 @@ struct Compiler return !node->is() && !node->is(); } - struct AssignmentVisitor : AstVisitor - { - struct Hasher - { - size_t operator()(const std::pair& p) const - { - return std::hash()(p.first) ^ std::hash()(p.second); - } - }; - - DenseHashMap localToTable; - DenseHashSet, Hasher> fields; - - AssignmentVisitor(Compiler* self) - : localToTable(nullptr) - , fields(std::pair()) - , self(self) - { - } - - void assignField(AstExpr* expr, AstName index) - { - if (AstExprLocal* lv = expr->as()) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - std::pair field = {*table, index}; - - if (!fields.contains(field)) - { - fields.insert(field); - self->predictedTableSize[*table].first += 1; - } - } - } - } - - void assignField(AstExpr* expr, AstExpr* index) - { - AstExprLocal* lv = expr->as(); - AstExprConstantNumber* number = index->as(); - - if (lv && number) - { - if (AstExprTable** table = localToTable.find(lv->local)) - { - unsigned int& arraySize = self->predictedTableSize[*table].second; - - if (number->value == double(arraySize + 1)) - arraySize += 1; - } - } - } - - void assign(AstExpr* var) - { - if (AstExprLocal* lv = var->as()) - { - self->locals[lv->local].written = true; - } - else if (AstExprGlobal* gv = var->as()) - { - self->globals[gv->name].written = true; - } - else if (AstExprIndexName* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else if (AstExprIndexExpr* index = var->as()) - { - assignField(index->expr, index->index); - - var->visit(this); - } - else - { - // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 - var->visit(this); - } - } - - AstExprTable* getTableHint(AstExpr* expr) - { - // unadorned table literal - if (AstExprTable* table = expr->as()) - return table; - - // setmetatable(table literal, ...) - if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) - if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") - if (AstExprTable* table = call->args.data[0]->as()) - return table; - - return nullptr; - } - - bool visit(AstStatLocal* node) override - { - // track local -> table association so that we can update table size prediction in assignField - if (node->vars.size == 1 && node->values.size == 1) - if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) - localToTable[node->vars.data[0]] = table; - - return true; - } - - bool visit(AstStatAssign* node) override - { - for (size_t i = 0; i < node->vars.size; ++i) - assign(node->vars.data[i]); - - for (size_t i = 0; i < node->values.size; ++i) - node->values.data[i]->visit(this); - - return false; - } - - bool visit(AstStatCompoundAssign* node) override - { - assign(node->var); - node->value->visit(this); - - return false; - } - - bool visit(AstStatFunction* node) override - { - assign(node->name); - node->func->visit(this); - - return false; - } - - Compiler* self; - }; - - struct ConstantVisitor : AstVisitor - { - ConstantVisitor(Compiler* self) - : self(self) - { - } - - void analyzeUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) - { - switch (op) - { - case AstExprUnary::Not: - if (arg.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !arg.isTruthful(); - } - break; - - case AstExprUnary::Minus: - if (arg.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = -arg.valueNumber; - } - break; - - case AstExprUnary::Len: - if (arg.type == Constant::Type_String) - { - result.type = Constant::Type_Number; - result.valueNumber = double(arg.valueString.size); - } - break; - - default: - LUAU_ASSERT(!"Unexpected unary operation"); - } - } - - bool constantsEqual(const Constant& la, const Constant& ra) - { - LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); - - switch (la.type) - { - case Constant::Type_Nil: - return ra.type == Constant::Type_Nil; - - case Constant::Type_Boolean: - return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; - - case Constant::Type_Number: - return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; - - case Constant::Type_String: - return ra.type == Constant::Type_String && la.valueString.size == ra.valueString.size && - memcmp(la.valueString.data, ra.valueString.data, la.valueString.size) == 0; - - default: - LUAU_ASSERT(!"Unexpected constant type in comparison"); - return false; - } - } - - void analyzeBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) - { - switch (op) - { - case AstExprBinary::Add: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber + ra.valueNumber; - } - break; - - case AstExprBinary::Sub: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - ra.valueNumber; - } - break; - - case AstExprBinary::Mul: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber * ra.valueNumber; - } - break; - - case AstExprBinary::Div: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber / ra.valueNumber; - } - break; - - case AstExprBinary::Mod: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; - } - break; - - case AstExprBinary::Pow: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Number; - result.valueNumber = pow(la.valueNumber, ra.valueNumber); - } - break; - - case AstExprBinary::Concat: - break; - - case AstExprBinary::CompareNe: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = !constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareEq: - if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = constantsEqual(la, ra); - } - break; - - case AstExprBinary::CompareLt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber < ra.valueNumber; - } - break; - - case AstExprBinary::CompareLe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber <= ra.valueNumber; - } - break; - - case AstExprBinary::CompareGt: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber > ra.valueNumber; - } - break; - - case AstExprBinary::CompareGe: - if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = la.valueNumber >= ra.valueNumber; - } - break; - - case AstExprBinary::And: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? ra : la; - } - break; - - case AstExprBinary::Or: - if (la.type != Constant::Type_Unknown) - { - result = la.isTruthful() ? la : ra; - } - break; - - default: - LUAU_ASSERT(!"Unexpected binary operation"); - } - } - - Constant analyze(AstExpr* node) - { - Constant result; - result.type = Constant::Type_Unknown; - - if (AstExprGroup* expr = node->as()) - { - result = analyze(expr->expr); - } - else if (node->is()) - { - result.type = Constant::Type_Nil; - } - else if (AstExprConstantBool* expr = node->as()) - { - result.type = Constant::Type_Boolean; - result.valueBoolean = expr->value; - } - else if (AstExprConstantNumber* expr = node->as()) - { - result.type = Constant::Type_Number; - result.valueNumber = expr->value; - } - else if (AstExprConstantString* expr = node->as()) - { - result.type = Constant::Type_String; - result.valueString = expr->value; - } - else if (AstExprLocal* expr = node->as()) - { - const Local* l = self->locals.find(expr->local); - - if (l && l->constant.type != Constant::Type_Unknown) - { - LUAU_ASSERT(!l->written); - result = l->constant; - } - } - else if (node->is()) - { - // nope - } - else if (node->is()) - { - // nope - } - else if (AstExprCall* expr = node->as()) - { - analyze(expr->func); - - for (size_t i = 0; i < expr->args.size; ++i) - analyze(expr->args.data[i]); - } - else if (AstExprIndexName* expr = node->as()) - { - analyze(expr->expr); - } - else if (AstExprIndexExpr* expr = node->as()) - { - analyze(expr->expr); - analyze(expr->index); - } - else if (AstExprFunction* expr = node->as()) - { - // this is necessary to propagate constant information in all child functions - expr->body->visit(this); - } - else if (AstExprTable* expr = node->as()) - { - for (size_t i = 0; i < expr->items.size; ++i) - { - const AstExprTable::Item& item = expr->items.data[i]; - - if (item.key) - analyze(item.key); - - analyze(item.value); - } - } - else if (AstExprUnary* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - analyzeUnary(result, expr->op, arg); - } - else if (AstExprBinary* expr = node->as()) - { - Constant la = analyze(expr->left); - Constant ra = analyze(expr->right); - - analyzeBinary(result, expr->op, la, ra); - } - else if (AstExprTypeAssertion* expr = node->as()) - { - Constant arg = analyze(expr->expr); - - result = arg; - } - else if (AstExprIfElse* expr = node->as(); FFlag::LuauIfElseExpressionBaseSupport && expr) - { - Constant cond = analyze(expr->condition); - Constant trueExpr = analyze(expr->trueExpr); - Constant falseExpr = analyze(expr->falseExpr); - if (cond.type != Constant::Type_Unknown) - { - result = cond.isTruthful() ? trueExpr : falseExpr; - } - } - else - { - LUAU_ASSERT(!"Unknown expression type"); - } - - if (result.type != Constant::Type_Unknown) - self->constants[node] = result; - - return result; - } - - bool visit(AstExpr* node) override - { - // note: we short-circuit the visitor traversal through any expression trees by returning false - // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression - analyze(node); - - return false; - } - - bool visit(AstStatLocal* node) override - { - // for values that match 1-1 we record the initializing expression for future analysis - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Local& l = self->locals[node->vars.data[i]]; - - l.init = node->values.data[i]; - } - - // all values that align wrt indexing are simple - we just match them 1-1 - for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) - { - Constant arg = analyze(node->values.data[i]); - - if (arg.type != Constant::Type_Unknown) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - l.constant = arg; - } - } - - if (node->vars.size > node->values.size) - { - // if we have trailing variables, then depending on whether the last value is capable of returning multiple values - // (aka call or varargs), we either don't know anything about these vars, or we know they're nil - AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; - bool multRet = last && (last->is() || last->is()); - - for (size_t i = node->values.size; i < node->vars.size; ++i) - { - if (!multRet) - { - Local& l = self->locals[node->vars.data[i]]; - - // note: we rely on AssignmentVisitor to have been run before us - if (!l.written) - { - l.constant.type = Constant::Type_Nil; - } - } - } - } - else - { - // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside - // them - for (size_t i = node->vars.size; i < node->values.size; ++i) - analyze(node->values.data[i]); - } - - return false; - } - - Compiler* self; - }; - struct FenvVisitor : AstVisitor { bool& getfenvUsed; @@ -3283,14 +2771,6 @@ struct Compiler return false; } - - bool visit(AstStatLocalFunction* node) override - { - // record local->function association for some optimizations - self->locals[node->name].func = node->func; - - return true; - } }; struct UndefinedLocalVisitor : AstVisitor @@ -3397,70 +2877,12 @@ struct Compiler std::vector upvals; }; - struct Constant - { - enum Type - { - Type_Unknown, - Type_Nil, - Type_Boolean, - Type_Number, - Type_String, - }; - - Type type = Type_Unknown; - - union - { - bool valueBoolean; - double valueNumber; - AstArray valueString = {}; - }; - - bool isTruthful() const - { - LUAU_ASSERT(type != Type_Unknown); - return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); - } - }; - struct Local { uint8_t reg = 0; bool allocated = false; bool captured = false; - bool written = false; - AstExpr* init = nullptr; uint32_t debugpc = 0; - Constant constant; - AstExprFunction* func = nullptr; - }; - - struct Global - { - bool writable = false; - bool written = false; - }; - - struct Builtin - { - AstName object; - AstName method; - - bool empty() const - { - return object == AstName() && method == AstName(); - } - - bool isGlobal(const char* name) const - { - return object == AstName() && method == name; - } - - bool isMethod(const char* table, const char* name) const - { - return object == table && method == name; - } }; struct LoopJump @@ -3482,194 +2904,6 @@ struct Compiler AstExpr* untilCondition; }; - Builtin getBuiltin(AstExpr* node) - { - if (AstExprLocal* expr = node->as()) - { - Local* l = locals.find(expr->local); - - return l && !l->written && l->init ? getBuiltin(l->init) : Builtin(); - } - else if (AstExprIndexName* expr = node->as()) - { - if (AstExprGlobal* object = expr->expr->as()) - { - Global* g = globals.find(object->name); - - return !g || (!g->writable && !g->written) ? Builtin{object->name, expr->index} : Builtin(); - } - else - { - return Builtin(); - } - } - else if (AstExprGlobal* expr = node->as()) - { - Global* g = globals.find(expr->name); - - return !g || !g->written ? Builtin{AstName(), expr->name} : Builtin(); - } - else - { - return Builtin(); - } - } - - int getBuiltinFunctionId(const Builtin& builtin) - { - if (builtin.empty()) - return -1; - - if (builtin.isGlobal("assert")) - return LBF_ASSERT; - - if (builtin.isGlobal("type")) - return LBF_TYPE; - - if (builtin.isGlobal("typeof")) - return LBF_TYPEOF; - - if (builtin.isGlobal("rawset")) - return LBF_RAWSET; - if (builtin.isGlobal("rawget")) - return LBF_RAWGET; - if (builtin.isGlobal("rawequal")) - return LBF_RAWEQUAL; - - if (builtin.isGlobal("unpack")) - return LBF_TABLE_UNPACK; - - if (builtin.object == "math") - { - if (builtin.method == "abs") - return LBF_MATH_ABS; - if (builtin.method == "acos") - return LBF_MATH_ACOS; - if (builtin.method == "asin") - return LBF_MATH_ASIN; - if (builtin.method == "atan2") - return LBF_MATH_ATAN2; - if (builtin.method == "atan") - return LBF_MATH_ATAN; - if (builtin.method == "ceil") - return LBF_MATH_CEIL; - if (builtin.method == "cosh") - return LBF_MATH_COSH; - if (builtin.method == "cos") - return LBF_MATH_COS; - if (builtin.method == "deg") - return LBF_MATH_DEG; - if (builtin.method == "exp") - return LBF_MATH_EXP; - if (builtin.method == "floor") - return LBF_MATH_FLOOR; - if (builtin.method == "fmod") - return LBF_MATH_FMOD; - if (builtin.method == "frexp") - return LBF_MATH_FREXP; - if (builtin.method == "ldexp") - return LBF_MATH_LDEXP; - if (builtin.method == "log10") - return LBF_MATH_LOG10; - if (builtin.method == "log") - return LBF_MATH_LOG; - if (builtin.method == "max") - return LBF_MATH_MAX; - if (builtin.method == "min") - return LBF_MATH_MIN; - if (builtin.method == "modf") - return LBF_MATH_MODF; - if (builtin.method == "pow") - return LBF_MATH_POW; - if (builtin.method == "rad") - return LBF_MATH_RAD; - if (builtin.method == "sinh") - return LBF_MATH_SINH; - if (builtin.method == "sin") - return LBF_MATH_SIN; - if (builtin.method == "sqrt") - return LBF_MATH_SQRT; - if (builtin.method == "tanh") - return LBF_MATH_TANH; - if (builtin.method == "tan") - return LBF_MATH_TAN; - if (builtin.method == "clamp") - return LBF_MATH_CLAMP; - if (builtin.method == "sign") - return LBF_MATH_SIGN; - if (builtin.method == "round") - return LBF_MATH_ROUND; - } - - if (builtin.object == "bit32") - { - if (builtin.method == "arshift") - return LBF_BIT32_ARSHIFT; - if (builtin.method == "band") - return LBF_BIT32_BAND; - if (builtin.method == "bnot") - return LBF_BIT32_BNOT; - if (builtin.method == "bor") - return LBF_BIT32_BOR; - if (builtin.method == "bxor") - return LBF_BIT32_BXOR; - if (builtin.method == "btest") - return LBF_BIT32_BTEST; - if (builtin.method == "extract") - return LBF_BIT32_EXTRACT; - if (builtin.method == "lrotate") - return LBF_BIT32_LROTATE; - if (builtin.method == "lshift") - return LBF_BIT32_LSHIFT; - if (builtin.method == "replace") - return LBF_BIT32_REPLACE; - if (builtin.method == "rrotate") - return LBF_BIT32_RROTATE; - if (builtin.method == "rshift") - return LBF_BIT32_RSHIFT; - if (builtin.method == "countlz") - return LBF_BIT32_COUNTLZ; - if (builtin.method == "countrz") - return LBF_BIT32_COUNTRZ; - } - - if (builtin.object == "string") - { - if (builtin.method == "byte") - return LBF_STRING_BYTE; - if (builtin.method == "char") - return LBF_STRING_CHAR; - if (builtin.method == "len") - return LBF_STRING_LEN; - if (builtin.method == "sub") - return LBF_STRING_SUB; - } - - if (builtin.object == "table") - { - if (builtin.method == "insert") - return LBF_TABLE_INSERT; - if (builtin.method == "unpack") - return LBF_TABLE_UNPACK; - } - - if (options.vectorCtor) - { - if (options.vectorLib) - { - if (builtin.isMethod(options.vectorLib, options.vectorCtor)) - return LBF_VECTOR; - } - else - { - if (builtin.isGlobal(options.vectorCtor)) - return LBF_VECTOR; - } - } - - return -1; - } - BytecodeBuilder& bytecode; CompileOptions options; @@ -3677,8 +2911,9 @@ struct Compiler DenseHashMap functions; DenseHashMap locals; DenseHashMap globals; + DenseHashMap variables; DenseHashMap constants; - DenseHashMap> predictedTableSize; + DenseHashMap tableShapes; unsigned int regTop = 0; unsigned int stackSize = 0; @@ -3699,23 +2934,18 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName Compiler compiler(bytecode, options); // since access to some global objects may result in values that change over time, we block imports from non-readonly tables - if (AstName name = names.get("_G"); name.value) - compiler.globals[name].writable = true; + assignMutable(compiler.globals, names, options.mutableGlobals); - if (options.mutableGlobals) - for (const char** ptr = options.mutableGlobals; *ptr; ++ptr) - if (AstName name = names.get(*ptr); name.value) - compiler.globals[name].writable = true; + // this pass analyzes mutability of locals/globals and associates locals with their initial values + trackValues(compiler.globals, compiler.variables, root); - // this visitor traverses the AST to analyze mutability of locals/globals, filling Local::written and Global::written - Compiler::AssignmentVisitor assignmentVisitor(&compiler); - root->visit(&assignmentVisitor); - - // this visitor traverses the AST to analyze constantness of expressions, filling constants[] and Local::constant/Local::init if (options.optimizationLevel >= 1) { - Compiler::ConstantVisitor constantVisitor(&compiler); - root->visit(&constantVisitor); + // this pass analyzes constantness of expressions + foldConstants(compiler.constants, compiler.variables, root); + + // this pass analyzes table assignments to estimate table shapes for initially empty tables + predictTableShapes(compiler.tableShapes, root); } // this visitor tracks calls to getfenv/setfenv and disables some optimizations when they are found @@ -3734,8 +2964,8 @@ void compileOrThrow(BytecodeBuilder& bytecode, AstStatBlock* root, const AstName for (AstExprFunction* expr : functions) compiler.compileFunction(expr); - AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), /* self= */ nullptr, - AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); + AstExprFunction main(root->location, /*generics= */ AstArray(), /*genericPacks= */ AstArray(), + /* self= */ nullptr, AstArray(), /* vararg= */ Luau::Location(), root, /* functionDepth= */ 0, /* debugname= */ AstName()); uint32_t mainid = compiler.compileFunction(&main); bytecode.setMainFunction(mainid); diff --git a/Compiler/src/ConstantFolding.cpp b/Compiler/src/ConstantFolding.cpp new file mode 100644 index 000000000..60a7c1692 --- /dev/null +++ b/Compiler/src/ConstantFolding.cpp @@ -0,0 +1,394 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ConstantFolding.h" + +#include + +namespace Luau +{ +namespace Compile +{ + +static bool constantsEqual(const Constant& la, const Constant& ra) +{ + LUAU_ASSERT(la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown); + + switch (la.type) + { + case Constant::Type_Nil: + return ra.type == Constant::Type_Nil; + + case Constant::Type_Boolean: + return ra.type == Constant::Type_Boolean && la.valueBoolean == ra.valueBoolean; + + case Constant::Type_Number: + return ra.type == Constant::Type_Number && la.valueNumber == ra.valueNumber; + + case Constant::Type_String: + return ra.type == Constant::Type_String && la.stringLength == ra.stringLength && memcmp(la.valueString, ra.valueString, la.stringLength) == 0; + + default: + LUAU_ASSERT(!"Unexpected constant type in comparison"); + return false; + } +} + +static void foldUnary(Constant& result, AstExprUnary::Op op, const Constant& arg) +{ + switch (op) + { + case AstExprUnary::Not: + if (arg.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !arg.isTruthful(); + } + break; + + case AstExprUnary::Minus: + if (arg.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = -arg.valueNumber; + } + break; + + case AstExprUnary::Len: + if (arg.type == Constant::Type_String) + { + result.type = Constant::Type_Number; + result.valueNumber = double(arg.stringLength); + } + break; + + default: + LUAU_ASSERT(!"Unexpected unary operation"); + } +} + +static void foldBinary(Constant& result, AstExprBinary::Op op, const Constant& la, const Constant& ra) +{ + switch (op) + { + case AstExprBinary::Add: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber + ra.valueNumber; + } + break; + + case AstExprBinary::Sub: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - ra.valueNumber; + } + break; + + case AstExprBinary::Mul: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber * ra.valueNumber; + } + break; + + case AstExprBinary::Div: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber / ra.valueNumber; + } + break; + + case AstExprBinary::Mod: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = la.valueNumber - floor(la.valueNumber / ra.valueNumber) * ra.valueNumber; + } + break; + + case AstExprBinary::Pow: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Number; + result.valueNumber = pow(la.valueNumber, ra.valueNumber); + } + break; + + case AstExprBinary::Concat: + break; + + case AstExprBinary::CompareNe: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = !constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareEq: + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = constantsEqual(la, ra); + } + break; + + case AstExprBinary::CompareLt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber < ra.valueNumber; + } + break; + + case AstExprBinary::CompareLe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber <= ra.valueNumber; + } + break; + + case AstExprBinary::CompareGt: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber > ra.valueNumber; + } + break; + + case AstExprBinary::CompareGe: + if (la.type == Constant::Type_Number && ra.type == Constant::Type_Number) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = la.valueNumber >= ra.valueNumber; + } + break; + + case AstExprBinary::And: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? ra : la; + } + break; + + case AstExprBinary::Or: + if (la.type != Constant::Type_Unknown) + { + result = la.isTruthful() ? la : ra; + } + break; + + default: + LUAU_ASSERT(!"Unexpected binary operation"); + } +} + +struct ConstantVisitor : AstVisitor +{ + DenseHashMap& constants; + DenseHashMap& variables; + + DenseHashMap locals; + + ConstantVisitor(DenseHashMap& constants, DenseHashMap& variables) + : constants(constants) + , variables(variables) + , locals(nullptr) + { + } + + Constant analyze(AstExpr* node) + { + Constant result; + result.type = Constant::Type_Unknown; + + if (AstExprGroup* expr = node->as()) + { + result = analyze(expr->expr); + } + else if (node->is()) + { + result.type = Constant::Type_Nil; + } + else if (AstExprConstantBool* expr = node->as()) + { + result.type = Constant::Type_Boolean; + result.valueBoolean = expr->value; + } + else if (AstExprConstantNumber* expr = node->as()) + { + result.type = Constant::Type_Number; + result.valueNumber = expr->value; + } + else if (AstExprConstantString* expr = node->as()) + { + result.type = Constant::Type_String; + result.valueString = expr->value.data; + result.stringLength = unsigned(expr->value.size); + } + else if (AstExprLocal* expr = node->as()) + { + const Constant* l = locals.find(expr->local); + + if (l) + result = *l; + } + else if (node->is()) + { + // nope + } + else if (node->is()) + { + // nope + } + else if (AstExprCall* expr = node->as()) + { + analyze(expr->func); + + for (size_t i = 0; i < expr->args.size; ++i) + analyze(expr->args.data[i]); + } + else if (AstExprIndexName* expr = node->as()) + { + analyze(expr->expr); + } + else if (AstExprIndexExpr* expr = node->as()) + { + analyze(expr->expr); + analyze(expr->index); + } + else if (AstExprFunction* expr = node->as()) + { + // this is necessary to propagate constant information in all child functions + expr->body->visit(this); + } + else if (AstExprTable* expr = node->as()) + { + for (size_t i = 0; i < expr->items.size; ++i) + { + const AstExprTable::Item& item = expr->items.data[i]; + + if (item.key) + analyze(item.key); + + analyze(item.value); + } + } + else if (AstExprUnary* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + if (arg.type != Constant::Type_Unknown) + foldUnary(result, expr->op, arg); + } + else if (AstExprBinary* expr = node->as()) + { + Constant la = analyze(expr->left); + Constant ra = analyze(expr->right); + + if (la.type != Constant::Type_Unknown && ra.type != Constant::Type_Unknown) + foldBinary(result, expr->op, la, ra); + } + else if (AstExprTypeAssertion* expr = node->as()) + { + Constant arg = analyze(expr->expr); + + result = arg; + } + else if (AstExprIfElse* expr = node->as()) + { + Constant cond = analyze(expr->condition); + Constant trueExpr = analyze(expr->trueExpr); + Constant falseExpr = analyze(expr->falseExpr); + + if (cond.type != Constant::Type_Unknown) + result = cond.isTruthful() ? trueExpr : falseExpr; + } + else + { + LUAU_ASSERT(!"Unknown expression type"); + } + + if (result.type != Constant::Type_Unknown) + constants[node] = result; + + return result; + } + + bool visit(AstExpr* node) override + { + // note: we short-circuit the visitor traversal through any expression trees by returning false + // recursive traversal is happening inside analyze() which makes it easier to get the resulting value of the subexpression + analyze(node); + + return false; + } + + bool visit(AstStatLocal* node) override + { + // all values that align wrt indexing are simple - we just match them 1-1 + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + { + Constant arg = analyze(node->values.data[i]); + + if (arg.type != Constant::Type_Unknown) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]] = arg; + v->constant = true; + } + } + } + + if (node->vars.size > node->values.size) + { + // if we have trailing variables, then depending on whether the last value is capable of returning multiple values + // (aka call or varargs), we either don't know anything about these vars, or we know they're nil + AstExpr* last = node->values.size ? node->values.data[node->values.size - 1] : nullptr; + bool multRet = last && (last->is() || last->is()); + + if (!multRet) + { + for (size_t i = node->values.size; i < node->vars.size; ++i) + { + // note: we rely on trackValues to have been run before us + Variable* v = variables.find(node->vars.data[i]); + LUAU_ASSERT(v); + + if (!v->written) + { + locals[node->vars.data[i]].type = Constant::Type_Nil; + v->constant = true; + } + } + } + } + else + { + // we can have more values than variables; in this case we still need to analyze them to make sure we do constant propagation inside + // them + for (size_t i = node->vars.size; i < node->values.size; ++i) + analyze(node->values.data[i]); + } + + return false; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root) +{ + ConstantVisitor visitor{constants, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ConstantFolding.h b/Compiler/src/ConstantFolding.h new file mode 100644 index 000000000..c0e63539b --- /dev/null +++ b/Compiler/src/ConstantFolding.h @@ -0,0 +1,48 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "ValueTracking.h" + +namespace Luau +{ +namespace Compile +{ + +struct Constant +{ + enum Type + { + Type_Unknown, + Type_Nil, + Type_Boolean, + Type_Number, + Type_String, + }; + + Type type = Type_Unknown; + unsigned int stringLength = 0; + + union + { + bool valueBoolean; + double valueNumber; + char* valueString = nullptr; // length stored in stringLength + }; + + bool isTruthful() const + { + LUAU_ASSERT(type != Type_Unknown); + return type != Type_Nil && !(type == Type_Boolean && valueBoolean == false); + } + + AstArray getString() const + { + LUAU_ASSERT(type == Type_String); + return {valueString, stringLength}; + } +}; + +void foldConstants(DenseHashMap& constants, DenseHashMap& variables, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp new file mode 100644 index 000000000..7d99f2228 --- /dev/null +++ b/Compiler/src/TableShape.cpp @@ -0,0 +1,129 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "TableShape.h" + +namespace Luau +{ +namespace Compile +{ + +static AstExprTable* getTableHint(AstExpr* expr) +{ + // unadorned table literal + if (AstExprTable* table = expr->as()) + return table; + + // setmetatable(table literal, ...) + if (AstExprCall* call = expr->as(); call && !call->self && call->args.size == 2) + if (AstExprGlobal* func = call->func->as(); func && func->name == "setmetatable") + if (AstExprTable* table = call->args.data[0]->as()) + return table; + + return nullptr; +} + +struct ShapeVisitor : AstVisitor +{ + struct Hasher + { + size_t operator()(const std::pair& p) const + { + return std::hash()(p.first) ^ std::hash()(p.second); + } + }; + + DenseHashMap& shapes; + + DenseHashMap tables; + DenseHashSet, Hasher> fields; + + ShapeVisitor(DenseHashMap& shapes) + : shapes(shapes) + , tables(nullptr) + , fields(std::pair()) + { + } + + void assignField(AstExpr* expr, AstName index) + { + if (AstExprLocal* lv = expr->as()) + { + if (AstExprTable** table = tables.find(lv->local)) + { + std::pair field = {*table, index}; + + if (!fields.contains(field)) + { + fields.insert(field); + shapes[*table].hashSize += 1; + } + } + } + } + + void assignField(AstExpr* expr, AstExpr* index) + { + AstExprLocal* lv = expr->as(); + AstExprConstantNumber* number = index->as(); + + if (lv && number) + { + if (AstExprTable** table = tables.find(lv->local)) + { + TableShape& shape = shapes[*table]; + + if (number->value == double(shape.arraySize + 1)) + shape.arraySize += 1; + } + } + } + + void assign(AstExpr* var) + { + if (AstExprIndexName* index = var->as()) + { + assignField(index->expr, index->index); + } + else if (AstExprIndexExpr* index = var->as()) + { + assignField(index->expr, index->index); + } + } + + bool visit(AstStatLocal* node) override + { + // track local -> table association so that we can update table size prediction in assignField + if (node->vars.size == 1 && node->values.size == 1) + if (AstExprTable* table = getTableHint(node->values.data[0]); table && table->items.size == 0) + tables[node->vars.data[0]] = table; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root) +{ + ShapeVisitor visitor{shapes}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/TableShape.h b/Compiler/src/TableShape.h new file mode 100644 index 000000000..f30853a7f --- /dev/null +++ b/Compiler/src/TableShape.h @@ -0,0 +1,21 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +namespace Compile +{ + +struct TableShape +{ + unsigned int arraySize = 0; + unsigned int hashSize = 0; +}; + +void predictTableShapes(DenseHashMap& shapes, AstNode* root); + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.cpp b/Compiler/src/ValueTracking.cpp new file mode 100644 index 000000000..0bfaf9b3c --- /dev/null +++ b/Compiler/src/ValueTracking.cpp @@ -0,0 +1,103 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "ValueTracking.h" + +#include "Luau/Lexer.h" + +namespace Luau +{ +namespace Compile +{ + +struct ValueVisitor : AstVisitor +{ + DenseHashMap& globals; + DenseHashMap& variables; + + ValueVisitor(DenseHashMap& globals, DenseHashMap& variables) + : globals(globals) + , variables(variables) + { + } + + void assign(AstExpr* var) + { + if (AstExprLocal* lv = var->as()) + { + variables[lv->local].written = true; + } + else if (AstExprGlobal* gv = var->as()) + { + globals[gv->name] = Global::Written; + } + else + { + // we need to be able to track assignments in all expressions, including crazy ones like t[function() t = nil end] = 5 + var->visit(this); + } + } + + bool visit(AstStatLocal* node) override + { + for (size_t i = 0; i < node->vars.size && i < node->values.size; ++i) + variables[node->vars.data[i]].init = node->values.data[i]; + + for (size_t i = node->values.size; i < node->vars.size; ++i) + variables[node->vars.data[i]].init = nullptr; + + return true; + } + + bool visit(AstStatAssign* node) override + { + for (size_t i = 0; i < node->vars.size; ++i) + assign(node->vars.data[i]); + + for (size_t i = 0; i < node->values.size; ++i) + node->values.data[i]->visit(this); + + return false; + } + + bool visit(AstStatCompoundAssign* node) override + { + assign(node->var); + node->value->visit(this); + + return false; + } + + bool visit(AstStatLocalFunction* node) override + { + variables[node->name].init = node->func; + + return true; + } + + bool visit(AstStatFunction* node) override + { + assign(node->name); + node->func->visit(this); + + return false; + } +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals) +{ + if (AstName name = names.get("_G"); name.value) + globals[name] = Global::Mutable; + + if (mutableGlobals) + for (const char** ptr = mutableGlobals; *ptr; ++ptr) + if (AstName name = names.get(*ptr); name.value) + globals[name] = Global::Mutable; +} + +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root) +{ + ValueVisitor visitor{globals, variables}; + root->visit(&visitor); +} + +} // namespace Compile +} // namespace Luau diff --git a/Compiler/src/ValueTracking.h b/Compiler/src/ValueTracking.h new file mode 100644 index 000000000..fc74c84aa --- /dev/null +++ b/Compiler/src/ValueTracking.h @@ -0,0 +1,42 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Ast.h" +#include "Luau/DenseHash.h" + +namespace Luau +{ +class AstNameTable; +} + +namespace Luau +{ +namespace Compile +{ + +enum class Global +{ + Default = 0, + Mutable, // builtin that has contents unknown at compile time, blocks GETIMPORT for chains + Written, // written in the code which means we can't reason about the value +}; + +struct Variable +{ + AstExpr* init = nullptr; // initial value of the variable; filled by trackValues + bool written = false; // is the variable ever assigned to? filled by trackValues + bool constant = false; // is the variable's value a compile-time constant? filled by constantFold +}; + +void assignMutable(DenseHashMap& globals, const AstNameTable& names, const char** mutableGlobals); +void trackValues(DenseHashMap& globals, DenseHashMap& variables, AstNode* root); + +inline Global getGlobalState(const DenseHashMap& globals, AstName name) +{ + const Global* it = globals.find(name); + + return it ? *it : Global::Default; +} + +} // namespace Compile +} // namespace Luau diff --git a/Sources.cmake b/Sources.cmake index 5dd486aaa..bafe75948 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -29,7 +29,15 @@ target_sources(Luau.Compiler PRIVATE Compiler/src/BytecodeBuilder.cpp Compiler/src/Compiler.cpp + Compiler/src/Builtins.cpp + Compiler/src/ConstantFolding.cpp + Compiler/src/TableShape.cpp + Compiler/src/ValueTracking.cpp Compiler/src/lcode.cpp + Compiler/src/Builtins.h + Compiler/src/ConstantFolding.h + Compiler/src/TableShape.h + Compiler/src/ValueTracking.h ) # Luau.Analysis Sources diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index eb47971a8..3cce7665d 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,7 +17,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauCcallRestoreFix, false) LUAU_FASTFLAG(LuauCoroutineClose) /* @@ -545,11 +544,8 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e if (!oldactive) resetbit(L->stackstate, THREAD_ACTIVEBIT); - if (FFlag::LuauCcallRestoreFix) - { - // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. - L->nCcalls = oldnCcalls; - } + // Restore nCcalls before calling the debugprotectederror callback which may rely on the proper value to have been restored. + L->nCcalls = oldnCcalls; // an error occurred, check if we have a protected error callback if (L->global->cb.debugprotectederror) @@ -564,10 +560,6 @@ int luaD_pcall(lua_State* L, Pfunc func, void* u, ptrdiff_t old_top, ptrdiff_t e StkId oldtop = restorestack(L, old_top); luaF_close(L, oldtop); /* close eventual pending closures */ seterrorobj(L, status, oldtop); - if (!FFlag::LuauCcallRestoreFix) - { - L->nCcalls = oldnCcalls; - } L->ci = restoreci(L, old_ci); L->base = L->ci->base; restore_stack_limit(L); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 648785697..4178eda40 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,6 +6,8 @@ #include "lmem.h" #include "lgc.h" +LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) + Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); @@ -113,14 +115,16 @@ void luaF_freeupval(lua_State* L, UpVal* uv) void luaF_close(lua_State* L, StkId level) { UpVal* uv; - global_State* g = L->global; + global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); L->openupval = uv->next; /* remove from `open' list */ - if (isdead(g, o)) + if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) + { luaF_freeupval(L, uv); /* free upvalue */ + } else { unlinkupval(uv); diff --git a/VM/src/lstrlib.cpp b/VM/src/lstrlib.cpp index 0b3054ae4..74a8aa8a1 100644 --- a/VM/src/lstrlib.cpp +++ b/VM/src/lstrlib.cpp @@ -8,8 +8,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauStrPackUBCastFix, false) - /* macro to `unsign' a character */ #define uchar(c) ((unsigned char)(c)) @@ -1406,20 +1404,10 @@ static int str_pack(lua_State* L) } case Kuint: { /* unsigned integers */ - if (FFlag::LuauStrPackUBCastFix) - { - long long n = (long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, (unsigned long long)n, h.islittle, size, 0); - } - else - { - unsigned long long n = (unsigned long long)luaL_checknumber(L, arg); - if (size < SZINT) /* need overflow check? */ - luaL_argcheck(L, n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); - packint(&b, n, h.islittle, size, 0); - } + long long n = (long long)luaL_checknumber(L, arg); + if (size < SZINT) /* need overflow check? */ + luaL_argcheck(L, (unsigned long long)n < ((unsigned long long)1 << (size * NB)), arg, "unsigned overflow"); + packint(&b, (unsigned long long)n, h.islittle, size, 0); break; } case Kfloat: diff --git a/bench/tests/sunspider/3d-cube.lua b/bench/tests/sunspider/3d-cube.lua index 6d40406cd..5d162ab99 100644 --- a/bench/tests/sunspider/3d-cube.lua +++ b/bench/tests/sunspider/3d-cube.lua @@ -111,15 +111,10 @@ end -- multiplies two matrices function MMulti(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do - M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; j = j + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][1] * M2[1][j] + M1[i][2] * M2[2][j] + M1[i][3] * M2[3][j] + M1[i][4] * M2[4][j]; end - - i = i + 1 end return M; end @@ -127,28 +122,27 @@ end -- multiplies matrix with vector function VMulti(M, V) local Vect = {}; - local i = 1; - while i <= 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; i = i + 1 end + for i = 1,4 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3] + M[i][4] * V[4]; + end return Vect; end function VMulti2(M, V) local Vect = {}; - local i = 1; - while i < 4 do Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; i = i + 1 end + for i = 1,3 do + Vect[i] = M[i][1] * V[1] + M[i][2] * V[2] + M[i][3] * V[3]; + end return Vect; end -- add to matrices function MAdd(M1, M2) local M = {{},{},{},{}}; - local i = 1; - local j = 1; - while i <= 4 do - j = 1; - while j <= 4 do M[i][j] = M1[i][j] + M2[i][j]; j = j + 1 end - - i = i + 1 + for i = 1,4 do + for j = 1,4 do + M[i][j] = M1[i][j] + M2[i][j]; + end end return M; end diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 210db7eea..211e1be1f 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1938,7 +1938,6 @@ return target(b@1 TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function bar(a: number) return -a end @@ -1954,7 +1953,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompletePreferToCallFunctions("LuauAutocompletePreferToCallFunctions", true); check(R"( local function foo() return 1 end @@ -2538,10 +2536,6 @@ TEST_CASE("autocomplete_documentation_symbols") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { check(R"( local temp = false local even = true; @@ -2614,7 +2608,6 @@ a = if temp then even elseif true then temp else e@9 CHECK(ac.entryMap.count("then") == 0); CHECK(ac.entryMap.count("else") == 0); CHECK(ac.entryMap.count("elseif") == 0); - } } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") @@ -2681,4 +2674,58 @@ local r4 = t:bar1(@4) CHECK(ac.entryMap["foo2"].typeCorrect == TypeCorrectKind::None); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + check(R"( +type A = () -> T + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("number")); + CHECK(ac.entryMap.count("string")); +} + +TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") +{ + ScopedFastFlag flag("LuauMissingFollowACMetatables", true); + check(R"( +--!strict +local Class = {} +Class.__index = Class +type Class = typeof(setmetatable({} :: { x: number }, Class)) +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end +function Class.getx(self: Class) + return self.x +end +function test() + local c = Class.new(42) + local n = c:@1 + print(n) +end + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("getx")); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 95811b3f4..8eed953fd 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -603,9 +603,9 @@ RETURN R0 1 )"); } -TEST_CASE("EmptyTableHashSizePredictionOptimization") +TEST_CASE("TableSizePredictionBasic") { - const char* hashSizeSource = R"( + CHECK_EQ("\n" + compileFunction0(R"( local t = {} t.a = 1 t.b = 1 @@ -616,36 +616,8 @@ t.f = 1 t.g = 1 t.h = 1 t.i = 1 -)"; - - const char* hashSizeSource2 = R"( -local t = {} -t.x = 1 -t.x = 2 -t.x = 3 -t.x = 4 -t.x = 5 -t.x = 6 -t.x = 7 -t.x = 8 -t.x = 9 -)"; - - const char* arraySizeSource = R"( -local t = {} -t[1] = 1 -t[2] = 1 -t[3] = 1 -t[4] = 1 -t[5] = 1 -t[6] = 1 -t[7] = 1 -t[8] = 1 -t[9] = 1 -t[10] = 1 -)"; - - CHECK_EQ("\n" + compileFunction0(hashSizeSource), R"( +)"), + R"( NEWTABLE R0 16 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -668,7 +640,19 @@ SETTABLEKS R1 R0 K8 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(hashSizeSource2), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t.x = 1 +t.x = 2 +t.x = 3 +t.x = 4 +t.x = 5 +t.x = 6 +t.x = 7 +t.x = 8 +t.x = 9 +)"), + R"( NEWTABLE R0 1 0 LOADN R1 1 SETTABLEKS R1 R0 K0 @@ -691,7 +675,20 @@ SETTABLEKS R1 R0 K0 RETURN R0 0 )"); - CHECK_EQ("\n" + compileFunction0(arraySizeSource), R"( + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +t[1] = 1 +t[2] = 1 +t[3] = 1 +t[4] = 1 +t[5] = 1 +t[6] = 1 +t[7] = 1 +t[8] = 1 +t[9] = 1 +t[10] = 1 +)"), + R"( NEWTABLE R0 0 10 LOADN R1 1 SETTABLEN R1 R0 1 @@ -717,6 +714,27 @@ RETURN R0 0 )"); } +TEST_CASE("TableSizePredictionObject") +{ + CHECK_EQ("\n" + compileFunction(R"( +local t = {} +t.field = 1 +function t:getfield() + return self.field +end +return t +)", + 1), + R"( +NEWTABLE R0 2 0 +LOADN R1 1 +SETTABLEKS R1 R0 K0 +DUPCLOSURE R1 K1 +SETTABLEKS R1 R0 K2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionSetMetatable") { CHECK_EQ("\n" + compileFunction0(R"( @@ -1031,9 +1049,6 @@ RETURN R0 1 TEST_CASE("IfElseExpression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - // codegen for a true constant condition CHECK_EQ("\n" + compileFunction0("return if true then 10 else 20"), R"( LOADN R0 10 @@ -3058,7 +3073,7 @@ RETURN R0 0 // table variants (indexed by string, number, variable) CHECK_EQ("\n" + compileFunction0("local a = {} a.foo += 5"), R"( -NEWTABLE R0 1 0 +NEWTABLE R0 0 0 GETTABLEKS R1 R0 K0 ADDK R1 R1 K1 SETTABLEKS R1 R0 K0 @@ -3066,7 +3081,7 @@ RETURN R0 0 )"); CHECK_EQ("\n" + compileFunction0("local a = {} a[1] += 5"), R"( -NEWTABLE R0 0 1 +NEWTABLE R0 0 0 GETTABLEN R1 R0 1 ADDK R1 R1 K0 SETTABLEN R1 R0 1 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 663b329ee..5222af33e 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -366,15 +366,11 @@ TEST_CASE("PCall") TEST_CASE("Pack") { - ScopedFastFlag sff{"LuauStrPackUBCastFix", true}; - runConformance("tpack.lua"); } TEST_CASE("Vector") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - lua_CompileOptions copts = {}; copts.optimizationLevel = 1; copts.debugLevel = 1; @@ -861,15 +857,11 @@ TEST_CASE("ExceptionObject") TEST_CASE("IfElseExpression") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; - runConformance("ifelseexpr.lua"); } TEST_CASE("TagMethodError") { - ScopedFastFlag sff{"LuauCcallRestoreFix", true}; - runConformance("tmerror.lua", [](lua_State* L) { auto* cb = lua_callbacks(L); diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index ca4281a0b..c74bfa272 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -191,7 +191,7 @@ ParseResult Fixture::tryParse(const std::string& source, const ParseOptions& par return result; } -ParseResult Fixture::matchParseError(const std::string& source, const std::string& message) +ParseResult Fixture::matchParseError(const std::string& source, const std::string& message, std::optional location) { ParseOptions options; options.allowDeclarationSyntax = true; @@ -203,6 +203,9 @@ ParseResult Fixture::matchParseError(const std::string& source, const std::strin CHECK_EQ(result.errors.front().getMessage(), message); + if (location) + CHECK_EQ(result.errors.front().getLocation(), *location); + return result; } diff --git a/tests/Fixture.h b/tests/Fixture.h index e01632eab..ab852ef6d 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -106,7 +106,7 @@ struct Fixture /// Parse with all language extensions enabled ParseResult parseEx(const std::string& source, const ParseOptions& parseOptions = {}); ParseResult tryParse(const std::string& source, const ParseOptions& parseOptions = {}); - ParseResult matchParseError(const std::string& source, const std::string& message); + ParseResult matchParseError(const std::string& source, const std::string& message, std::optional location = std::nullopt); // Verify a parse error occurs and the parse error message has the specified prefix ParseResult matchParseErrorPrefix(const std::string& source, const std::string& prefix); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 5abcb09ac..e91356515 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1255,7 +1255,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_type_group") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if true then if true then if true then if true then if true then if true then if true then if true then if true " @@ -1266,7 +1265,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_if_statements") TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements") { ScopedFastInt sfis{"LuauRecursionLimit", 10}; - ScopedFastFlag sff{"LuauIfStatementRecursionGuard", true}; matchParseErrorPrefix( "function f() if false then elseif false then elseif false then elseif false then elseif false then elseif false then elseif " @@ -1276,7 +1274,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_changed_elseif_statements" TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError("function f() return if true then 1 elseif true then 2 elseif true then 3 elseif true then 4 elseif true then 5 elseif true then " @@ -1286,7 +1283,6 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions1 TEST_CASE_FIXTURE(Fixture, "parse_error_with_too_many_nested_ifelse_expressions2") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; ScopedFastInt sfis{"LuauRecursionLimit", 10}; matchParseError( @@ -1962,6 +1958,37 @@ TEST_CASE_FIXTURE(Fixture, "function_type_named_arguments") matchParseError("type MyFunc = (number) -> (d: number) -> number", "Expected '->' when parsing function type, got '<'"); } +TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") +{ + matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + AstStat* stat = parse(R"( +type A = {} +type B = {} +type C = {} +type D = {} +type E = {} +type F = (T...) -> U... +type G = (U...) -> T... + )"); + + REQUIRE(stat != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + + matchParseError("type Y = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}}); + matchParseError("type Y = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}}); + matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2455,10 +2482,19 @@ do end CHECK_EQ(1, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { - ScopedFastFlag sff{"LuauIfElseExpressionBaseSupport", true}; + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauParseRecoverTypePackEllipsis{"LuauParseRecoverTypePackEllipsis", true}; + + ParseResult result = tryParse(R"( +type Y = (T...) -> U... + )"); + CHECK_EQ(1, result.errors.size()); +} +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ { AstStat* stat = parse("return if true then 1 else 2"); @@ -2524,9 +2560,4 @@ type C = Packed<(number, X...)> REQUIRE(stat != nullptr); } -TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") -{ - matchParseError("local a: (number -> string", "Expected ')' (to close '(' at column 13), got '->'"); -} - TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 80a258f5b..445ee5329 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -338,6 +338,8 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed") TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") { + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local base = {} function base:one() return 1 end @@ -353,7 +355,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringDetailed2") TypeId tType = requireType("inst"); ToStringResult r = toStringDetailed(tType); - CHECK_EQ("{ @metatable {| __index: { @metatable {| __index: base |}, child } |}, inst }", r.name); + CHECK_EQ("{ @metatable { __index: { @metatable { __index: base }, child } }, inst }", r.name); CHECK_EQ(0, r.nameMap.typeVars.size()); ToStringOptions opts; @@ -500,6 +502,24 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") CHECK_EQ("map(arr: {a}, fn: (a) -> b): {b}", toStringNamedFunction("map", *ftv)); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") +{ + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( + local function f(a: number, b: string) end + local function test(...: T...): U... + f(...) + return 1, 2, 3 + end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + CHECK_EQ("test(...: T...): U...", toStringNamedFunction("test", *ftv)); +} + TEST_CASE("toStringNamedFunction_unit_f") { TypePackVar empty{TypePack{}}; diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 47c3883c1..ac5be859b 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -421,8 +421,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_type_assertion") TEST_CASE_FIXTURE(Fixture, "transpile_if_then_else") { - ScopedFastFlag luauIfElseExpressionBaseSupport("LuauIfElseExpressionBaseSupport", true); - std::string code = "local a = if 1 then 2 else 3"; CHECK_EQ(code, transpile(code).code); @@ -641,4 +639,16 @@ TEST_CASE_FIXTURE(Fixture, "transpile_to_string") CHECK_EQ("'hello'", toString(expr)); } +TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + std::string code = R"( +type Packed = (T, U, V...)->(W...) +local a: Packed + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 275782b30..86165814c 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -497,7 +497,7 @@ TEST_CASE_FIXTURE(Fixture, "generic_aliases_are_cloned_properly") CHECK(arrayTable->indexer); CHECK(isInArena(array.type, mod.interfaceTypes)); - CHECK_EQ(array.typeParams[0], arrayTable->indexer->indexResultType); + CHECK_EQ(array.typeParams[0].ty, arrayTable->indexer->indexResultType); } TEST_CASE_FIXTURE(Fixture, "cloned_interface_maintains_pointers_between_definitions") diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 503b613f4..d76b920bb 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1031,9 +1031,6 @@ TEST_CASE_FIXTURE(Fixture, "refine_the_correct_types_opposite_of_when_a_is_not_n TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if v then v else tostring(v) @@ -1048,9 +1045,6 @@ TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function f(v:string?) return if not v then tostring(v) else v @@ -1065,9 +1059,6 @@ TEST_CASE_FIXTURE(Fixture, "invert_is_truthy_constraint_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - CheckResult result = check(R"( function returnOne(x) return 1 @@ -1119,6 +1110,25 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ CHECK_EQ("string", toString(requireTypeAtPosition({5, 30}))); } +TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2") +{ + ScopedFastFlag sff{"LuauLValueAsKey", true}; + + CheckResult result = check(R"( + type T = { x: { y: number }? } + + local function f(t: T?) + if t and t.x then + local foo = t.x.y + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireTypeAtPosition({5, 32}))); +} + TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 68dc1b4fa..94cfb6437 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -360,6 +360,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { ScopedFastFlag sffs[] = { {"LuauParseSingletonTypes", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -369,7 +370,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(R"(Table type '{| ["\n"]: number |}' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", + CHECK_EQ(R"(Table type '{ ["\n"]: number }' not compatible with type '{| ["<>"]: number |}' because the former is missing field '<>')", toString(result.errors[0])); } @@ -423,4 +424,27 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") +{ + ScopedFastFlag sffs[] = { + {"LuauSingletonTypes", true}, + {"LuauParseSingletonTypes", true}, + {"LuauUnionHeuristic", true}, + {"LuauExpectedTypesOfProperties", true}, + {"LuauExtendedUnionMismatchError", true}, + {"LuauIfElseExpectedType2", true}, + {"LuauIfElseBranchTypeUnion", true}, + }; + + CheckResult result = check(R"( +type Cat = { tag: 'cat', catfood: string } +type Dog = { tag: 'dog', dogfood: string } +type Animal = Cat | Dog + +local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag = 'dog', dogfood = 'other' } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 80f40407e..27cda1463 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -65,7 +65,7 @@ TEST_CASE_FIXTURE(Fixture, "augment_nested_table") TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") { - CheckResult result = check("local t = {prop=999} t.foo = 'bar'"); + CheckResult result = check("function mkt() return {prop=999} end local t = mkt() t.foo = 'bar'"); LUAU_REQUIRE_ERROR_COUNT(1, result); TypeError& err = result.errors[0]; @@ -77,7 +77,7 @@ TEST_CASE_FIXTURE(Fixture, "cannot_augment_sealed_table") CHECK_EQ(s, "{| prop: number |}"); CHECK_EQ(error->prop, "foo"); CHECK_EQ(error->context, CannotExtendTable::Property); - CHECK_EQ(err.location, (Location{Position{0, 24}, Position{0, 29}})); + CHECK_EQ(err.location, (Location{Position{0, 59}, Position{0, 64}})); } TEST_CASE_FIXTURE(Fixture, "dont_seal_an_unsealed_table_by_passing_it_to_a_function_that_takes_a_sealed_table") @@ -1155,7 +1155,8 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_builtin_sealed_table_mu TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t.m() end )"); @@ -1165,13 +1166,38 @@ TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_sealed_table_must_fail TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_sealed_table_must_fail") { CheckResult result = check(R"( - local t = {x = 1} + function mkt() return {x = 1} end + local t = mkt() function t:m() end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "defining_a_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t.m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "defining_a_self_method_for_a_local_unsealed_table_is_ok") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( + local t = {x = 1} + function t:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + // This unit test could be flaky if the fix has regressed. TEST_CASE_FIXTURE(Fixture, "pass_incompatible_union_to_a_generic_table_without_crashing") { @@ -1439,8 +1465,13 @@ TEST_CASE_FIXTURE(Fixture, "right_table_missing_key2") CHECK_EQ("{| |}", toString(mp->subType)); } -TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") +TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_indexer") { + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; + CheckResult result = check(R"( type StringToStringMap = { [string]: string } local rt: StringToStringMap = { ["foo"] = 1 } @@ -1448,6 +1479,25 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer") LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); + // Should t now have an indexer? + // It would if the assignment to rt was correctly typed. + CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); +} + +TEST_CASE_FIXTURE(Fixture, "casting_sealed_tables_with_props_into_table_with_indexer") +{ + CheckResult result = check(R"( + type StringToStringMap = { [string]: string } + function mkrt() return { ["foo"] = 1 } end + local rt: StringToStringMap = mkrt() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + ToStringOptions o{/* exhaustive= */ true}; TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); @@ -1467,7 +1517,10 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer2") TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") { - ScopedFastFlag sff{"LuauTableSubtypingVariance2", true}; + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local function foo(a: {[string]: number, a: string}) end @@ -1480,7 +1533,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer3") TypeMismatch* tm = get(result.errors[0]); REQUIRE(tm); CHECK_EQ("{| [string]: number, a: string |}", toString(tm->wantedType, o)); - CHECK_EQ("{| a: number |}", toString(tm->givenType, o)); + CHECK_EQ("{ a: number }", toString(tm->givenType, o)); } TEST_CASE_FIXTURE(Fixture, "casting_tables_with_props_into_table_with_indexer4") @@ -1536,8 +1589,11 @@ TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_missing_props_dont_report_multi TEST_CASE_FIXTURE(Fixture, "table_subtyping_with_extra_props_dont_report_multiple_errors") { CheckResult result = check(R"( - local vec3 = {{x = 1, y = 2, z = 3}} - local vec1 = {{x = 1}} + function mkvec3() return {x = 1, y = 2, z = 3} end + function mkvec1() return {x = 1} end + + local vec3 = {mkvec3()} + local vec1 = {mkvec1()} vec1 = vec3 )"); @@ -1620,7 +1676,8 @@ TEST_CASE_FIXTURE(Fixture, "reasonable_error_when_adding_a_nonexistent_property_ { CheckResult result = check(R"( --!strict - local A = {"value"} + function mkA() return {"value"} end + local A = mkA() A.B = "Hello" )"); @@ -1668,7 +1725,8 @@ TEST_CASE_FIXTURE(Fixture, "hide_table_error_properties") --!strict local function f() - local t = { x = 1 } + local function mkt() return { x = 1 } end + local t = mkt() function t.a() end function t.b() end @@ -1995,7 +2053,10 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_metatable_prop") { - ScopedFastFlag LuauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag sff[]{ + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local a1 = setmetatable({ x = 2, y = 3 }, { __call = function(s) end }); @@ -2010,7 +2071,7 @@ local c2: typeof(a2) = b2 LUAU_REQUIRE_ERROR_COUNT(2, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'b1' could not be converted into 'a1' caused by: - Type '{| x: number, y: string |}' could not be converted into '{| x: number, y: number |}' + Type '{ x: number, y: string }' could not be converted into '{ x: number, y: number }' caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); @@ -2018,7 +2079,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); } @@ -2026,7 +2087,7 @@ caused by: { CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: - Type '{| __call: (a, b) -> () |}' could not be converted into '{| __call: (a) -> () |}' + Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); } @@ -2059,6 +2120,7 @@ TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, }; CheckResult result = check(R"( @@ -2077,7 +2139,7 @@ local y: number = tmp.p.y LUAU_REQUIRE_ERROR_COUNT(1, result); CHECK_EQ(toString(result.errors[0]), R"(Type 'tmp' could not be converted into 'HasSuper' caused by: - Property 'p' is not compatible. Table type '{| x: number, y: number |}' not compatible with type 'Super' because the former has extra field 'y')"); + Property 'p' is not compatible. Table type '{ x: number, y: number }' not compatible with type 'Super' because the former has extra field 'y')"); } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") @@ -2103,7 +2165,10 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") { - ScopedFastFlag luauFixRecursiveMetatableCall{"LuauFixRecursiveMetatableCall", true}; + ScopedFastFlag sff[]{ + {"LuauFixRecursiveMetatableCall", true}, + {"LuauUnsealedTableLiteral", true}, + }; CheckResult result = check(R"( local b @@ -2112,7 +2177,7 @@ b() )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable {| __call: t1 |}, { } })"); + CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f70f3b1c8..7a056af52 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4525,7 +4525,9 @@ f(function(x) print(x) end) } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + CheckResult result = check(R"( local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end return sum(2, 3, function(a, b) return a + b end) @@ -4549,7 +4551,7 @@ local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} e )"); LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{| c: number, s: number |}", toString(requireType("r"))); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") @@ -4689,6 +4691,18 @@ a = setmetatable(a, { __call = function(x) end }) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr") +{ + ScopedFastFlag luauGroupExpectedType{"LuauGroupExpectedType", true}; + + CheckResult result = check(R"( +local function f(a: (number, number) -> number) return a(1, 3) end +f(((function(a, b) return a + b end))) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "refine_and_or") { CheckResult result = check(R"( @@ -4743,46 +4757,75 @@ local c: X TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; - - { - CheckResult result = check(R"(local a = if true then "true" else "false")"); - LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"(local a = if true then "true" else "false")"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions2") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + // Test expression containing elseif + CheckResult result = check(R"( +local a = if false then "a" elseif false then "b" else "c" + )"); + LUAU_REQUIRE_NO_ERRORS(result); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") +{ + ScopedFastFlag sff3{"LuauIfElseBranchTypeUnion", true}; { - // Test expression containing elseif - CheckResult result = check(R"( -local a = if false then "a" elseif false then "b" else "c" - )"); + CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); + LUAU_REQUIRE_NO_ERRORS(result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); + CHECK_EQ(toString(requireType("a"), {true}), "number?"); } } -TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions3") +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1") { - ScopedFastFlag sff1{"LuauIfElseExpressionBaseSupport", true}; - ScopedFastFlag sff2{"LuauIfElseExpressionAnalysisSupport", true}; + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - { - CheckResult result = check(R"(local a = if true then "true" else 42)"); - // We currently require both true/false expressions to unify to the same type. However, we do intend to lift - // this restriction in the future. - LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::String); - } + CheckResult result = check(R"( +type X = {number | string} +local a: X = if true then {"1", 2, 3} else {4, 5, 6} +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "{number | string}"); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + ScopedFastFlag luauIfElseBranchTypeUnion{ "LuauIfElseBranchTypeUnion", true }; + + CheckResult result = check(R"( +local a: number? = if true then 1 else nil +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") +{ + ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; + + CheckResult result = check(R"( +local function times(n: any, f: () -> T) + local result: {T} = {} + local res = f() + table.insert(result, if true then res else n) + return result +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "type_error_addition") @@ -5039,4 +5082,51 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "table_oop") +{ + CheckResult result = check(R"( + --!strict +local Class = {} +Class.__index = Class + +type Class = typeof(setmetatable({} :: { x: number }, Class)) + +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end + +function Class.getx(self: Class) + return self.x +end + +function test() + local c = Class.new(42) + local n = c:getx() + local nn = c.x + + print(string.format("%d %d", n, nn)) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") +{ + ScopedFastFlag luauMetatableAreEqualRecursion{"LuauMetatableAreEqualRecursion", true}; + + CheckResult result = check(R"( +local function getIt() + local y + y = setmetatable({}, y) + return y +end +local a = getIt() +local b = getIt() +local c = a or b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 5d37b032a..d4878d149 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -621,4 +621,328 @@ type Other = Packed CHECK_EQ(toString(result.errors[0]), "Generic type 'Packed' expects 2 type pack arguments, but only 1 is specified"); } +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = 2, b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y = { a: T } + +local a: Y = { a = 2 } +local b: Y<> = { a = "s" } +local c: Y = { a = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U } + +local a: Y = { a = 2, b = 3 } +local b: Y = { a = "h", b = "s" } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + + result = check(R"( +type Y string> = { a: T, b: U } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y string>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: U, c: V } + +local a: Y +local b: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T, b: (U...) -> T } + +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T...) -> U..., b: (T...) -> V... } +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y<(number, string), (number, string), (number, string)>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: (T, U, V...) -> W... } +local a: Y +local b: Y +local c: Y +local d: Y ()> + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "Y"); + CHECK_EQ(toString(requireType("b")), "Y"); + CHECK_EQ(toString(requireType("c")), "Y"); + CHECK_EQ(toString(requireType("d")), "Y ()>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = { a: T } +local a: Y = { a = 2 } + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T...) -> () } +local a: Y<> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type 'T'"); + + result = check(R"( +type Y = { a: (T) -> U... } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Generic type 'Y' expects at least 1 type argument, but none are specified"); + + result = check(R"( +type Packed = (T) -> T +local a: Packed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Type parameter list is required"); + + result = check(R"( +type Y = { a: T } +local a: Y + )"); + + LUAU_REQUIRE_ERRORS(result); + + result = check(R"( +type Y = { a: T } +local a: Y<...number> + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + fileResolver.source["Module/Types"] = R"( +export type A = { a: T, b: U } +export type B = { a: T, b: U } +export type C string> = { a: T, b: U } +export type D = { a: T, b: U, c: V } +export type E = { a: (T...) -> () } +export type F = { a: T, b: (U...) -> T } +export type G = { b: (U...) -> T... } +export type H = { b: (T...) -> T... } +return {} + )"; + + CheckResult resultTypes = frontend.check("Module/Types"); + LUAU_REQUIRE_NO_ERRORS(resultTypes); + + fileResolver.source["Module/Users"] = R"( +local Types = require(script.Parent.Types) + +local a: Types.A +local b: Types.B +local c: Types.C +local d: Types.D +local e: Types.E<> +local eVoid: Types.E<()> +local f: Types.F +local g: Types.G<...number> +local h: Types.H<> + )"; + + CheckResult resultUsers = frontend.check("Module/Users"); + LUAU_REQUIRE_NO_ERRORS(resultUsers); + + CHECK_EQ(toString(requireType("Module/Users", "a")), "A"); + CHECK_EQ(toString(requireType("Module/Users", "b")), "B"); + CHECK_EQ(toString(requireType("Module/Users", "c")), "C string>"); + CHECK_EQ(toString(requireType("Module/Users", "d")), "D"); + CHECK_EQ(toString(requireType("Module/Users", "e")), "E"); + CHECK_EQ(toString(requireType("Module/Users", "eVoid")), "E<>"); + CHECK_EQ(toString(requireType("Module/Users", "f")), "F"); + CHECK_EQ(toString(requireType("Module/Users", "g")), "G<...number, ()>"); + CHECK_EQ(toString(requireType("Module/Users", "h")), "H<>"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type Y = (T...) -> number +local a: Y + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(requireType("a")), "(...string) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type A = (T, V...) -> (U, W...) +type B = A +type C = A + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("B"), {true}), "(string, ...any) -> (number, ...any)"); + CHECK_EQ(toString(*lookupType("C"), {true}), "(string, boolean) -> (number, boolean)"); +} + +TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type") +{ + ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; + ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; + + CheckResult result = check(R"( +type F ()> = (K) -> V +type R = { m: F } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(toString(*lookupType("R"), {true}), "t1 where t1 = {| m: (t1) -> (t1) -> () |}"); +} + +TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") +{ + ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; + ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; + + CheckResult result = check(R"( +local a: () -> (number, ...string) +local b: () -> (number, ...boolean) +a = b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> (number, ...boolean)' could not be converted into '() -> (number, ...string)' +caused by: + Type 'boolean' could not be converted into 'string')"); +} + TEST_SUITE_END(); From d70a0788c5b7993cc535c19b2bb732d56b760663 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:23:02 -0800 Subject: [PATCH 15/32] Sync to upstream/release/511 --- Analysis/include/Luau/TypeVar.h | 4 + Analysis/src/Linter.cpp | 15 ++ Analysis/src/TypeInfer.cpp | 27 ++- Analysis/src/TypeVar.cpp | 46 +++++ Analysis/src/Unifier.cpp | 6 + Ast/include/Luau/Ast.h | 3 +- Ast/include/Luau/DenseHash.h | 8 +- Ast/src/Parser.cpp | 20 +- CLI/Repl.cpp | 110 +++++++--- CMakeLists.txt | 6 + Compiler/src/TableShape.cpp | 49 ++++- LICENSE.txt | 2 +- VM/include/lua.h | 2 +- VM/src/lapi.cpp | 2 +- VM/src/ldo.cpp | 9 +- VM/src/lfunc.cpp | 96 ++++++--- VM/src/lfunc.h | 7 +- VM/src/lgc.cpp | 251 +++++++++++++++++++---- VM/src/lgc.h | 1 + VM/src/lgcdebug.cpp | 77 +++++-- VM/src/lmem.cpp | 347 +++++++++++++++++++++++++++++++- VM/src/lmem.h | 15 ++ VM/src/lobject.h | 16 +- VM/src/lstate.cpp | 37 +++- VM/src/lstate.h | 15 +- VM/src/lstring.cpp | 136 +++++++++---- VM/src/lstring.h | 2 +- VM/src/ltable.cpp | 8 +- VM/src/ltable.h | 2 +- VM/src/ludata.cpp | 6 +- VM/src/ludata.h | 2 +- VM/src/lvmexecute.cpp | 2 +- tests/Compiler.test.cpp | 24 +++ tests/Linter.test.cpp | 9 +- tests/Parser.test.cpp | 13 +- tests/TypeInfer.tables.test.cpp | 48 +++++ tests/TypeInfer.test.cpp | 29 +++ tests/conformance/closure.lua | 15 ++ tests/conformance/gc.lua | 28 +++ 39 files changed, 1277 insertions(+), 218 deletions(-) diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index fd2c2afa7..3f5e26d66 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #pragma once +#include "Luau/DenseHash.h" #include "Luau/Predicate.h" #include "Luau/Unifiable.h" #include "Luau/Variant.h" @@ -499,6 +500,9 @@ bool maybeGeneric(const TypeId ty); // Checks if a type is of the form T1|...|Tn where one of the Ti is a singleton bool maybeSingleton(TypeId ty); +// Checks if the length operator can be applied on the value of type +bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount); + struct SingletonTypes { const TypeId nilType; diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 1a5b24fe2..905b70bff 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauLintTableCreateTable, false) + namespace Luau { @@ -2153,6 +2155,19 @@ class LintTableOperations : AstVisitor "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } + if (FFlag::LuauLintTableCreateTable && func->index == "create" && node->args.size == 2) + { + // table.create(n, {...}) + if (args[1]->is()) + emitWarning(*context, LintWarning::Code_TableOperations, args[1]->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + + // table.create(n, {...} :: ?) + if (AstExprTypeAssertion* as = args[1]->as(); as && as->expr->is()) + emitWarning(*context, LintWarning::Code_TableOperations, as->expr->location, + "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + } + return true; } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index bedcc0227..e2d8a4fb2 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -33,6 +33,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) +LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) @@ -2066,17 +2067,27 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn if (get(operandType)) return {errorRecoveryType(scope)}; - if (get(operandType)) - return {numberType}; // Not strictly correct: metatables permit overriding this - - if (auto p = get(operandType)) + if (FFlag::LuauLengthOnCompositeType) { - if (p->type == PrimitiveTypeVar::String) - return {numberType}; + DenseHashSet seen{nullptr}; + + if (!hasLength(operandType, seen, &recursionCount)) + reportError(TypeError{expr.location, NotATable{operandType}}); } + else + { + if (get(operandType)) + return {numberType}; // Not strictly correct: metatables permit overriding this - if (!getTableType(operandType)) - reportError(TypeError{expr.location, NotATable{operandType}}); + if (auto p = get(operandType)) + { + if (p->type == PrimitiveTypeVar::String) + return {numberType}; + } + + if (!getTableType(operandType)) + reportError(TypeError{expr.location, NotATable{operandType}}); + } return {numberType}; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index ac2b25410..df5d76ed0 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -5,6 +5,7 @@ #include "Luau/Common.h" #include "Luau/DenseHash.h" #include "Luau/Error.h" +#include "Luau/RecursionCounter.h" #include "Luau/StringUtils.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" @@ -19,6 +20,8 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) +LUAU_FASTINT(LuauTypeInferRecursionLimit) +LUAU_FASTFLAG(LuauLengthOnCompositeType) LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) LUAU_FASTFLAG(LuauErrorRecoveryType) @@ -326,6 +329,49 @@ bool maybeSingleton(TypeId ty) return false; } +bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) +{ + LUAU_ASSERT(FFlag::LuauLengthOnCompositeType); + + RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); + + ty = follow(ty); + + if (seen.contains(ty)) + return true; + + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty) || get(ty) || get(ty)) + return true; + + if (auto uty = get(ty)) + { + seen.insert(ty); + + for (TypeId part : uty->options) + { + if (!hasLength(part, seen, recursionCount)) + return false; + } + + return true; + } + + if (auto ity = get(ty)) + { + seen.insert(ty); + + for (TypeId part : ity->parts) + { + if (hasLength(part, seen, recursionCount)) + return true; + } + + return false; + } + + return false; +} + FunctionTypeVar::FunctionTypeVar(TypePackId argTypes, TypePackId retType, std::optional defn, bool hasSelf) : argTypes(argTypes) , retType(retType) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6873c657c..2bd9cf83f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -13,6 +13,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); +LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); @@ -99,6 +100,11 @@ struct PromoteTypeLevels bool operator()(TypePackId tp, const FreeTypePack&) { + // Surprise, it's actually a BoundTypePack that hasn't been committed yet. + // Calling getMutable on this will trigger an assertion. + if (FFlag::LuauCommittingTxnLogFreeTpPromote && FFlag::LuauUseCommittingTxnLog && !log.is(tp)) + return true; + promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); return true; } diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index 573850a5a..ac5950c0a 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -1265,7 +1265,8 @@ struct hash size_t operator()(const Luau::AstName& value) const { // note: since operator== uses pointer identity, hashing function uses it as well - return value.value ? std::hash()(value.value) : 0; + // the hasher is the same as DenseHashPointer (DenseHash.h) + return (uintptr_t(value.value) >> 4) ^ (uintptr_t(value.value) >> 9); } }; diff --git a/Ast/include/Luau/DenseHash.h b/Ast/include/Luau/DenseHash.h index a7b2515a6..65939bee1 100644 --- a/Ast/include/Luau/DenseHash.h +++ b/Ast/include/Luau/DenseHash.h @@ -12,10 +12,6 @@ namespace Luau { -// Internal implementation of DenseHashSet and DenseHashMap -namespace detail -{ - struct DenseHashPointer { size_t operator()(const void* key) const @@ -24,6 +20,10 @@ struct DenseHashPointer } }; +// Internal implementation of DenseHashSet and DenseHashMap +namespace detail +{ + template using DenseHashDefault = std::conditional_t, DenseHashPointer, std::hash>; diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 77787cb1c..3c607d24c 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) +LUAU_FASTFLAGVARIABLE(LuauStartingBrokenComment, false) namespace Luau { @@ -174,10 +175,23 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n const Lexeme::Type type = p.lexer.current().type; const Location loc = p.lexer.current().location; - p.lexer.next(); + if (FFlag::LuauStartingBrokenComment) + { + if (options.captureComments) + p.commentLocations.push_back(Comment{type, loc}); + + if (type == Lexeme::BrokenComment) + break; - if (options.captureComments) - p.commentLocations.push_back(Comment{type, loc}); + p.lexer.next(); + } + else + { + p.lexer.next(); + + if (options.captureComments) + p.commentLocations.push_back(Comment{type, loc}); + } } p.lexer.setSkipComments(true); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 36747f487..e50421528 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -35,10 +35,15 @@ enum class CompileFormat Binary }; +struct GlobalOptions +{ + int optimizationLevel = 1; +} globalOptions; + static Luau::CompileOptions copts() { Luau::CompileOptions result = {}; - result.optimizationLevel = 1; + result.optimizationLevel = globalOptions.optimizationLevel; result.debugLevel = 1; result.coverageLevel = coverageActive() ? 2 : 0; @@ -232,13 +237,14 @@ static std::string runCode(lua_State* L, const std::string& source) static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) { std::string_view lookup = editBuffer + start; + char lastSep = 0; for (;;) { - size_t dot = lookup.find('.'); - std::string_view prefix = lookup.substr(0, dot); + size_t sep = lookup.find_first_of(".:"); + std::string_view prefix = lookup.substr(0, sep); - if (dot == std::string_view::npos) + if (sep == std::string_view::npos) { // table, key lua_pushnil(L); @@ -249,11 +255,22 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, { // table, key, value std::string_view key = lua_tostring(L, -2); - - if (!key.empty() && Luau::startsWith(key, prefix)) - completions.push_back(editBuffer + std::string(key.substr(prefix.size()))); + int valueType = lua_type(L, -1); + + // If the last separator was a ':' (i.e. a method call) then only functions should be completed. + bool requiredValueType = (lastSep != ':' || valueType == LUA_TFUNCTION); + + if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) + { + std::string completion(editBuffer + std::string(key.substr(prefix.size()))); + if (valueType == LUA_TFUNCTION) + { + // Add an opening paren for function calls by default. + completion += "("; + } + completions.push_back(completion); + } } - lua_pop(L, 1); } @@ -266,10 +283,21 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, lua_rawget(L, -2); lua_remove(L, -2); - if (!lua_istable(L, -1)) + if (lua_type(L, -1) == LUA_TSTRING) + { + // Replace the string object with the string class to perform further lookups of string functions + // Note: We retrieve the string class from _G to prevent issues if the user assigns to `string`. + lua_getglobal(L, "_G"); + lua_pushlstring(L, "string", 6); + lua_rawget(L, -2); + lua_remove(L, -2); + LUAU_ASSERT(lua_istable(L, -1)); + } + else if (!lua_istable(L, -1)) break; - lookup.remove_prefix(dot + 1); + lastSep = lookup[sep]; + lookup.remove_prefix(sep + 1); } } @@ -279,7 +307,7 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) { size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == '_')) + while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == ':' || editBuffer[start - 1] == '_')) start--; // look the value up in current global table first @@ -319,15 +347,8 @@ struct LinenoiseScopedHistory std::string historyFilepath; }; -static void runRepl() +static void runReplImpl(lua_State* L) { - std::unique_ptr globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - setupState(L); - - luaL_sandboxthread(L); - linenoise::SetCompletionCallback([L](const char* editBuffer, std::vector& completions) { completeRepl(L, editBuffer, completions); }); @@ -368,7 +389,18 @@ static void runRepl() } } -static bool runFile(const char* name, lua_State* GL) +static void runRepl() +{ + std::unique_ptr globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + setupState(L); + luaL_sandboxthread(L); + runReplImpl(L); +} + +// `repl` is used it indicate if a repl should be started after executing the file. +static bool runFile(const char* name, lua_State* GL, bool repl) { std::optional source = readFile(name); if (!source) @@ -419,6 +451,10 @@ static bool runFile(const char* name, lua_State* GL) fprintf(stderr, "%s", error.c_str()); } + if (repl) + { + runReplImpl(L); + } lua_pop(GL, 1); return status == 0; } @@ -457,7 +493,7 @@ static bool compileFile(const char* name, CompileFormat format) bcb.setDumpSource(*source); } - Luau::compileOrThrow(bcb, *source); + Luau::compileOrThrow(bcb, *source, copts()); switch (format) { @@ -495,9 +531,11 @@ static void displayHelp(const char* argv0) printf(" --compile[=format]: compile input files and output resulting formatted bytecode (binary or text)\n"); printf("\n"); printf("Available options:\n"); + printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); printf(" -h, --help: Display this usage message.\n"); + printf(" -i, --interactive: Run an interactive REPL after executing the last script specified.\n"); + printf(" -O: use compiler optimization level (n=0-2).\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); - printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -519,6 +557,7 @@ int main(int argc, char** argv) CompileFormat compileFormat{}; int profile = 0; bool coverage = false; + bool interactive = false; // Set the mode if the user has explicitly specified one. int argStart = 1; @@ -540,8 +579,8 @@ int main(int argc, char** argv) } else { - fprintf(stdout, "Error: Unrecognized value for '--compile' specified.\n"); - return -1; + fprintf(stderr, "Error: Unrecognized value for '--compile' specified.\n"); + return 1; } } @@ -552,6 +591,20 @@ int main(int argc, char** argv) displayHelp(argv[0]); return 0; } + else if (strcmp(argv[i], "-i") == 0 || strcmp(argv[i], "--interactive") == 0) + { + interactive = true; + } + else if (strncmp(argv[i], "-O", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Optimization level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.optimizationLevel = level; + } else if (strcmp(argv[i], "--profile") == 0) { profile = 10000; // default to 10 KHz @@ -575,7 +628,7 @@ int main(int argc, char** argv) } else if (argv[i][0] == '-') { - fprintf(stdout, "Error: Unrecognized option '%s'.\n\n", argv[i]); + fprintf(stderr, "Error: Unrecognized option '%s'.\n\n", argv[i]); displayHelp(argv[0]); return 1; } @@ -623,8 +676,11 @@ int main(int argc, char** argv) int failed = 0; - for (const std::string& path : files) - failed += !runFile(path.c_str(), L); + for (size_t i = 0; i < files.size(); ++i) + { + bool isLastFile = i == files.size() - 1; + failed += !runFile(files[i].c_str(), L, interactive && isLastFile); + } if (profile) { diff --git a/CMakeLists.txt b/CMakeLists.txt index bafc59e59..77cf47e85 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -78,6 +78,12 @@ target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) +if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) + # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: + # https://developercommunity.visualstudio.com/t/performance-regression-on-a-complex-interpreter-lo/1631863 + set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) +endif() + if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp index 7d99f2228..9dc2f0a46 100644 --- a/Compiler/src/TableShape.cpp +++ b/Compiler/src/TableShape.cpp @@ -1,11 +1,16 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "TableShape.h" +LUAU_FASTFLAGVARIABLE(LuauPredictTableSizeLoop, false) + namespace Luau { namespace Compile { +// conservative limit for the loop bound that establishes table array size +static const int kMaxLoopBound = 16; + static AstExprTable* getTableHint(AstExpr* expr) { // unadorned table literal @@ -27,7 +32,7 @@ struct ShapeVisitor : AstVisitor { size_t operator()(const std::pair& p) const { - return std::hash()(p.first) ^ std::hash()(p.second); + return DenseHashPointer()(p.first) ^ std::hash()(p.second); } }; @@ -36,10 +41,13 @@ struct ShapeVisitor : AstVisitor DenseHashMap tables; DenseHashSet, Hasher> fields; + DenseHashMap loops; // iterator => upper bound for 1..k + ShapeVisitor(DenseHashMap& shapes) : shapes(shapes) , tables(nullptr) , fields(std::pair()) + , loops(nullptr) { } @@ -63,16 +71,31 @@ struct ShapeVisitor : AstVisitor void assignField(AstExpr* expr, AstExpr* index) { AstExprLocal* lv = expr->as(); - AstExprConstantNumber* number = index->as(); + if (!lv) + return; + + AstExprTable** table = tables.find(lv->local); + if (!table) + return; - if (lv && number) + if (AstExprConstantNumber* number = index->as()) { - if (AstExprTable** table = tables.find(lv->local)) + TableShape& shape = shapes[*table]; + + if (number->value == double(shape.arraySize + 1)) + shape.arraySize += 1; + } + else if (AstExprLocal* iter = index->as()) + { + if (!FFlag::LuauPredictTableSizeLoop) + return; + + if (const unsigned int* bound = loops.find(iter->local)) { TableShape& shape = shapes[*table]; - if (number->value == double(shape.arraySize + 1)) - shape.arraySize += 1; + if (shape.arraySize == 0) + shape.arraySize = *bound; } } } @@ -117,6 +140,20 @@ struct ShapeVisitor : AstVisitor return false; } + + bool visit(AstStatFor* node) override + { + if (!FFlag::LuauPredictTableSizeLoop) + return true; + + AstExprConstantNumber* from = node->from->as(); + AstExprConstantNumber* to = node->to->as(); + + if (from && to && from->value == 1.0 && to->value >= 1.0 && to->value <= double(kMaxLoopBound) && !node->step) + loops[node->var] = unsigned(to->value); + + return true; + } }; void predictTableShapes(DenseHashMap& shapes, AstNode* root) diff --git a/LICENSE.txt b/LICENSE.txt index d63e7299e..fa9914d78 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2019-2021 Roblox Corporation +Copyright (c) 2019-2022 Roblox Corporation Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in diff --git a/VM/include/lua.h b/VM/include/lua.h index 55902160c..c5dcef251 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -382,7 +382,7 @@ typedef struct lua_Callbacks lua_Callbacks; LUA_API lua_Callbacks* lua_callbacks(lua_State* L); /****************************************************************************** - * Copyright (c) 2019-2021 Roblox Corporation + * Copyright (c) 2019-2022 Roblox Corporation * Copyright (C) 1994-2008 Lua.org, PUC-Rio. All rights reserved. * * Permission is hereby granted, free of charge, to any person obtaining diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index c98b95908..d5416285e 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -18,7 +18,7 @@ const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Ri "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; -const char* luau_ident = "$Luau: Copyright (C) 2019-2021 Roblox Corporation $\n" +const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" "$URL: luau-lang.org $\n"; #define api_checknelems(L, n) api_check(L, (n) <= (L->top - L->base)) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 3cce7665d..581506a89 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -150,12 +150,11 @@ l_noret luaD_throw(lua_State* L, int errcode) static void correctstack(lua_State* L, TValue* oldstack) { - CallInfo* ci; - GCObject* up; L->top = (L->top - oldstack) + L->stack; - for (up = L->openupval; up != NULL; up = up->gch.next) - gco2uv(up)->v = (gco2uv(up)->v - oldstack) + L->stack; - for (ci = L->base_ci; ci <= L->ci; ci++) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (UpVal* up = L->openupval; up != NULL; up = (UpVal*)up->next) + up->v = (up->v - oldstack) + L->stack; + for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) { ci->top = (ci->top - oldstack) + L->stack; ci->base = (ci->base - oldstack) + L->stack; diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 4178eda40..6088f71c4 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -7,10 +7,11 @@ #include "lgc.h" LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) +LUAU_FASTFLAG(LuauGcPagedSweep) Proto* luaF_newproto(lua_State* L) { - Proto* f = luaM_new(L, Proto, sizeof(Proto), L->activememcat); + Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); luaC_link(L, f, LUA_TPROTO); f->k = NULL; f->sizek = 0; @@ -38,7 +39,7 @@ Proto* luaF_newproto(lua_State* L) Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) { - Closure* c = luaM_new(L, Closure, sizeLclosure(nelems), L->activememcat); + Closure* c = luaM_newgco(L, Closure, sizeLclosure(nelems), L->activememcat); luaC_link(L, c, LUA_TFUNCTION); c->isC = 0; c->env = e; @@ -53,7 +54,7 @@ Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) { - Closure* c = luaM_new(L, Closure, sizeCclosure(nelems), L->activememcat); + Closure* c = luaM_newgco(L, Closure, sizeCclosure(nelems), L->activememcat); luaC_link(L, c, LUA_TFUNCTION); c->isC = 1; c->env = e; @@ -69,10 +70,9 @@ Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) UpVal* luaF_findupval(lua_State* L, StkId level) { global_State* g = L->global; - GCObject** pp = &L->openupval; + UpVal** pp = &L->openupval; UpVal* p; - UpVal* uv; - while (*pp != NULL && (p = gco2uv(*pp))->v >= level) + while (*pp != NULL && (p = *pp)->v >= level) { LUAU_ASSERT(p->v != &p->u.value); if (p->v == level) @@ -81,53 +81,95 @@ UpVal* luaF_findupval(lua_State* L, StkId level) changewhite(obj2gco(p)); /* resurrect it */ return p; } - pp = &p->next; + + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + pp = (UpVal**)&p->next; } - uv = luaM_new(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ + + UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ uv->tt = LUA_TUPVAL; uv->marked = luaC_white(g); uv->memcat = L->activememcat; uv->v = level; /* current value lives in the stack */ - uv->next = *pp; /* chain it in the proper position */ - *pp = obj2gco(uv); - uv->u.l.prev = &g->uvhead; /* double link it in `uvhead' list */ + + // chain the upvalue in the threads open upvalue list at the proper position + UpVal* next = *pp; + + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + uv->next = (GCObject*)next; + + if (FFlag::LuauGcPagedSweep) + { + uv->u.l.threadprev = pp; + if (next) + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + next->u.l.threadprev = (UpVal**)&uv->next; + } + } + + *pp = uv; + + // double link the upvalue in the global open upvalue list + uv->u.l.prev = &g->uvhead; uv->u.l.next = g->uvhead.u.l.next; uv->u.l.next->u.l.prev = uv; g->uvhead.u.l.next = uv; LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); return uv; } - -static void unlinkupval(UpVal* uv) +void luaF_unlinkupval(UpVal* uv) { + // unlink upvalue from the global open upvalue list LUAU_ASSERT(uv->u.l.next->u.l.prev == uv && uv->u.l.prev->u.l.next == uv); - uv->u.l.next->u.l.prev = uv->u.l.prev; /* remove from `uvhead' list */ + uv->u.l.next->u.l.prev = uv->u.l.prev; uv->u.l.prev->u.l.next = uv->u.l.next; + + if (FFlag::LuauGcPagedSweep) + { + // unlink upvalue from the thread open upvalue list + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and this and the following cast will not be required + *uv->u.l.threadprev = (UpVal*)uv->next; + + if (UpVal* next = (UpVal*)uv->next) + next->u.l.threadprev = uv->u.l.threadprev; + } } -void luaF_freeupval(lua_State* L, UpVal* uv) +void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { if (uv->v != &uv->u.value) /* is it open? */ - unlinkupval(uv); /* remove from open list */ - luaM_free(L, uv, sizeof(UpVal), uv->memcat); /* free upvalue */ + luaF_unlinkupval(uv); /* remove from open list */ + luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); /* free upvalue */ } void luaF_close(lua_State* L, StkId level) { - UpVal* uv; global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval - while (L->openupval != NULL && (uv = gco2uv(L->openupval))->v >= level) + UpVal* uv; + while (L->openupval != NULL && (uv = L->openupval)->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); - L->openupval = uv->next; /* remove from `open' list */ - if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) + + if (!FFlag::LuauGcPagedSweep) + L->openupval = (UpVal*)uv->next; /* remove from `open' list */ + + if (FFlag::LuauGcPagedSweep && isdead(g, o)) + { + // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue + luaF_unlinkupval(uv); + // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again + uv->v = &uv->u.value; + } + else if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) { - luaF_freeupval(L, uv); /* free upvalue */ + luaF_freeupval(L, uv, NULL); /* free upvalue */ } else { - unlinkupval(uv); + // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue + luaF_unlinkupval(uv); setobj(L, &uv->u.value, uv->v); uv->v = &uv->u.value; /* now current value lives here */ luaC_linkupval(L, uv); /* link upvalue into `gcroot' list */ @@ -135,7 +177,7 @@ void luaF_close(lua_State* L, StkId level) } } -void luaF_freeproto(lua_State* L, Proto* f) +void luaF_freeproto(lua_State* L, Proto* f, lua_Page* page) { luaM_freearray(L, f->code, f->sizecode, Instruction, f->memcat); luaM_freearray(L, f->p, f->sizep, Proto*, f->memcat); @@ -146,13 +188,13 @@ void luaF_freeproto(lua_State* L, Proto* f) luaM_freearray(L, f->upvalues, f->sizeupvalues, TString*, f->memcat); if (f->debuginsn) luaM_freearray(L, f->debuginsn, f->sizecode, uint8_t, f->memcat); - luaM_free(L, f, sizeof(Proto), f->memcat); + luaM_freegco(L, f, sizeof(Proto), f->memcat, page); } -void luaF_freeclosure(lua_State* L, Closure* c) +void luaF_freeclosure(lua_State* L, Closure* c, lua_Page* page) { int size = c->isC ? sizeCclosure(c->nupvalues) : sizeLclosure(c->nupvalues); - luaM_free(L, c, size, c->memcat); + luaM_freegco(L, c, size, c->memcat, page); } const LocVar* luaF_getlocal(const Proto* f, int local_number, int pc) diff --git a/VM/src/lfunc.h b/VM/src/lfunc.h index 4be236675..8047cebe2 100644 --- a/VM/src/lfunc.h +++ b/VM/src/lfunc.h @@ -12,7 +12,8 @@ LUAI_FUNC Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p LUAI_FUNC Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e); LUAI_FUNC UpVal* luaF_findupval(lua_State* L, StkId level); LUAI_FUNC void luaF_close(lua_State* L, StkId level); -LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f); -LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c); -LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv); +LUAI_FUNC void luaF_freeproto(lua_State* L, Proto* f, struct lua_Page* page); +LUAI_FUNC void luaF_freeclosure(lua_State* L, Closure* c, struct lua_Page* page); +void luaF_unlinkupval(UpVal* uv); +LUAI_FUNC void luaF_freeupval(lua_State* L, UpVal* uv, struct lua_Page* page); LUAI_FUNC const LocVar* luaF_getlocal(const Proto* func, int local_number, int pc); diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 76ef7a06b..50859b1e8 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -8,12 +8,16 @@ #include "lfunc.h" #include "lstring.h" #include "ldo.h" +#include "lmem.h" #include "ludata.h" #include +LUAU_FASTFLAGVARIABLE(LuauGcPagedSweep, false) + #define GC_SWEEPMAX 40 #define GC_SWEEPCOST 10 +#define GC_SWEEPPAGESTEPCOST 4 #define GC_INTERRUPT(state) \ { \ @@ -457,31 +461,31 @@ static void shrinkstack(lua_State* L) condhardstacktests(luaD_reallocstack(L, s_used)); } -static void freeobj(lua_State* L, GCObject* o) +static void freeobj(lua_State* L, GCObject* o, lua_Page* page) { switch (o->gch.tt) { case LUA_TPROTO: - luaF_freeproto(L, gco2p(o)); + luaF_freeproto(L, gco2p(o), page); break; case LUA_TFUNCTION: - luaF_freeclosure(L, gco2cl(o)); + luaF_freeclosure(L, gco2cl(o), page); break; case LUA_TUPVAL: - luaF_freeupval(L, gco2uv(o)); + luaF_freeupval(L, gco2uv(o), page); break; case LUA_TTABLE: - luaH_free(L, gco2h(o)); + luaH_free(L, gco2h(o), page); break; case LUA_TTHREAD: LUAU_ASSERT(gco2th(o) != L && gco2th(o) != L->global->mainthread); - luaE_freethread(L, gco2th(o)); + luaE_freethread(L, gco2th(o), page); break; case LUA_TSTRING: - luaS_free(L, gco2ts(o)); + luaS_free(L, gco2ts(o), page); break; case LUA_TUSERDATA: - luaU_freeudata(L, gco2u(o)); + luaU_freeudata(L, gco2u(o), page); break; default: LUAU_ASSERT(0); @@ -492,6 +496,8 @@ static void freeobj(lua_State* L, GCObject* o) static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* traversedcount) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + GCObject* curr; global_State* g = L->global; int deadmask = otherwhite(g); @@ -502,7 +508,7 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; if (curr->gch.tt == LUA_TTHREAD) { - sweepwholelist(L, &gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ + sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ lua_State* th = gco2th(curr); @@ -524,7 +530,7 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr *p = curr->gch.next; if (curr == g->rootgc) /* is the first element of the list? */ g->rootgc = curr->gch.next; /* adjust first */ - freeobj(L, curr); + freeobj(L, curr, NULL); } } @@ -537,14 +543,16 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr static void deletelist(lua_State* L, GCObject** p, GCObject* limit) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + GCObject* curr; while ((curr = *p) != limit) { if (curr->gch.tt == LUA_TTHREAD) /* delete open upvalues of each thread */ - deletelist(L, &gco2th(curr)->openupval, NULL); + deletelist(L, (GCObject**)&gco2th(curr)->openupval, NULL); *p = curr->gch.next; - freeobj(L, curr); + freeobj(L, curr, NULL); } } @@ -567,23 +575,62 @@ static void shrinkbuffersfull(lua_State* L) luaS_resize(L, hashsize); /* table is too big */ } +static bool deletegco(void* context, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + // we are in the process of deleting everything + // threads with open upvalues will attempt to close them all on removal + // but those upvalues might point to stack values that were already deleted + if (gco->gch.tt == LUA_TTHREAD) + { + lua_State* th = gco2th(gco); + + while (UpVal* uv = th->openupval) + { + luaF_unlinkupval(uv); + // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again + uv->v = &uv->u.value; + } + } + + lua_State* L = (lua_State*)context; + freeobj(L, gco, page); + return true; +} + void luaC_freeall(lua_State* L) { global_State* g = L->global; LUAU_ASSERT(L == g->mainthread); - LUAU_ASSERT(L->next == NULL); /* mainthread is at the end of rootgc list */ - deletelist(L, &g->rootgc, obj2gco(L)); + if (FFlag::LuauGcPagedSweep) + { + luaM_visitgco(L, L, deletegco); + + for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + LUAU_ASSERT(g->strt.hash[i] == NULL); - for (int i = 0; i < g->strt.size; i++) /* free all string lists */ - deletelist(L, &g->strt.hash[i], NULL); + LUAU_ASSERT(L->global->strt.nuse == 0); + LUAU_ASSERT(g->strbufgc == NULL); + } + else + { + LUAU_ASSERT(L->next == NULL); /* mainthread is at the end of rootgc list */ - LUAU_ASSERT(L->global->strt.nuse == 0); - deletelist(L, &g->strbufgc, NULL); - // unfortunately, when string objects are freed, the string table use count is decremented - // even when the string is a buffer that wasn't placed into the table - L->global->strt.nuse = 0; + deletelist(L, &g->rootgc, obj2gco(L)); + + for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + deletelist(L, (GCObject**)&g->strt.hash[i], NULL); + + LUAU_ASSERT(L->global->strt.nuse == 0); + deletelist(L, (GCObject**)&g->strbufgc, NULL); + + // unfortunately, when string objects are freed, the string table use count is decremented + // even when the string is a buffer that wasn't placed into the table + L->global->strt.nuse = 0; + } } static void markmt(global_State* g) @@ -648,12 +695,88 @@ static size_t atomic(lua_State* L) /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); g->sweepstrgc = 0; - g->sweepgc = &g->rootgc; - g->gcstate = GCSsweepstring; + + if (FFlag::LuauGcPagedSweep) + { + g->sweepgcopage = g->allgcopages; + g->gcstate = GCSsweep; + } + else + { + g->sweepgc = &g->rootgc; + g->gcstate = GCSsweepstring; + } return work; } +static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + int deadmask = otherwhite(g); + LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); // make sure we never sweep fixed objects + + int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; + + g->gcstats.currcycle.sweepitems++; + + if (gco->gch.tt == LUA_TTHREAD) + { + lua_State* th = gco2th(gco); + + if (alive) + { + resetbit(th->stackstate, THREAD_SLEEPINGBIT); + shrinkstack(th); + } + } + + if (alive) + { + LUAU_ASSERT(!isdead(g, gco)); + makewhite(g, gco); // make it white (for next cycle) + return false; + } + + LUAU_ASSERT(isdead(g, gco)); + freeobj(L, gco, page); + return true; +} + +// a version of generic luaM_visitpage specialized for the main sweep stage +static int sweepgcopage(lua_State* L, lua_Page* page) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + char* start; + char* end; + int busyBlocks; + int blockSize; + luaM_getpagewalkinfo(page, &start, &end, &busyBlocks, &blockSize); + + for (char* pos = start; pos != end; pos += blockSize) + { + GCObject* gco = (GCObject*)pos; + + // skip memory blocks that are already freed + if (gco->gch.tt == LUA_TNIL) + continue; + + // when true is returned it means that the element was deleted + if (sweepgco(L, page, gco)) + { + // if the last block was removed, page would be removed as well + if (--busyBlocks == 0) + return (pos - start) / blockSize + 1; + } + } + + return (end - start) / blockSize; +} + static size_t gcstep(lua_State* L, size_t limit) { size_t cost = 0; @@ -706,15 +829,21 @@ static size_t gcstep(lua_State* L, size_t limit) g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; cost = atomic(L); /* finish mark phase */ - LUAU_ASSERT(g->gcstate == GCSsweepstring); + + if (FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->gcstate == GCSsweep); + else + LUAU_ASSERT(g->gcstate == GCSsweepstring); break; } case GCSsweepstring: { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + while (g->sweepstrgc < g->strt.size && cost < limit) { size_t traversedcount = 0; - sweepwholelist(L, &g->strt.hash[g->sweepstrgc++], &traversedcount); + sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++], &traversedcount); g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPCOST; @@ -727,7 +856,7 @@ static size_t gcstep(lua_State* L, size_t limit) uint32_t nuse = L->global->strt.nuse; size_t traversedcount = 0; - sweepwholelist(L, &g->strbufgc, &traversedcount); + sweepwholelist(L, (GCObject**)&g->strbufgc, &traversedcount); L->global->strt.nuse = nuse; @@ -738,19 +867,44 @@ static size_t gcstep(lua_State* L, size_t limit) } case GCSsweep: { - while (*g->sweepgc && cost < limit) + if (FFlag::LuauGcPagedSweep) { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + while (g->sweepgcopage && cost < limit) + { + lua_Page* next = luaM_getnextgcopage(g->sweepgcopage); // page sweep might destroy the page - g->gcstats.currcycle.sweepitems += traversedcount; - cost += GC_SWEEPMAX * GC_SWEEPCOST; + int steps = sweepgcopage(L, g->sweepgcopage); + + g->sweepgcopage = next; + cost += steps * GC_SWEEPPAGESTEPCOST; + } + + // nothing more to sweep? + if (g->sweepgcopage == NULL) + { + // don't forget to visit main thread + sweepgco(L, NULL, obj2gco(g->mainthread)); + + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } } + else + { + while (*g->sweepgc && cost < limit) + { + size_t traversedcount = 0; + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); - if (*g->sweepgc == NULL) - { /* nothing more to sweep? */ - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ + g->gcstats.currcycle.sweepitems += traversedcount; + cost += GC_SWEEPMAX * GC_SWEEPCOST; + } + + if (*g->sweepgc == NULL) + { /* nothing more to sweep? */ + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ + } } break; } @@ -877,12 +1031,19 @@ void luaC_fullgc(lua_State* L) { /* reset sweep marks to sweep all elements (returning them to white) */ g->sweepstrgc = 0; - g->sweepgc = &g->rootgc; + if (FFlag::LuauGcPagedSweep) + g->sweepgcopage = g->allgcopages; + else + g->sweepgc = &g->rootgc; /* reset other collector lists */ g->gray = NULL; g->grayagain = NULL; g->weak = NULL; - g->gcstate = GCSsweepstring; + + if (FFlag::LuauGcPagedSweep) + g->gcstate = GCSsweep; + else + g->gcstate = GCSsweepstring; } LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); /* finish any pending sweep phase */ @@ -979,8 +1140,11 @@ void luaC_barrierback(lua_State* L, Table* t) void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt) { global_State* g = L->global; - o->gch.next = g->rootgc; - g->rootgc = o; + if (!FFlag::LuauGcPagedSweep) + { + o->gch.next = g->rootgc; + g->rootgc = o; + } o->gch.marked = luaC_white(g); o->gch.tt = tt; o->gch.memcat = L->activememcat; @@ -990,8 +1154,13 @@ void luaC_linkupval(lua_State* L, UpVal* uv) { global_State* g = L->global; GCObject* o = obj2gco(uv); - o->gch.next = g->rootgc; /* link upvalue into `rootgc' list */ - g->rootgc = o; + + if (!FFlag::LuauGcPagedSweep) + { + o->gch.next = g->rootgc; /* link upvalue into `rootgc' list */ + g->rootgc = o; + } + if (isgray(o)) { if (keepinvariant(g)) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index f434e5064..4455fec5b 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -13,6 +13,7 @@ #define GCSpropagate 1 #define GCSpropagateagain 2 #define GCSatomic 3 +// TODO: remove with FFlagLuauGcPagedSweep #define GCSsweepstring 4 #define GCSsweep 5 diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index c66de9c1d..906fb0d04 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -2,16 +2,19 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lgc.h" +#include "lfunc.h" +#include "lmem.h" #include "lobject.h" #include "lstate.h" -#include "ltable.h" -#include "lfunc.h" #include "lstring.h" +#include "ltable.h" #include "ludata.h" #include #include +LUAU_FASTFLAG(LuauGcPagedSweep) + static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -101,10 +104,11 @@ static void validatestack(global_State* g, lua_State* l) if (l->namecall) validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - for (GCObject* uv = l->openupval; uv; uv = uv->gch.next) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (UpVal* uv = l->openupval; uv; uv = (UpVal*)uv->next) { - LUAU_ASSERT(uv->gch.tt == LUA_TUPVAL); - LUAU_ASSERT(gco2uv(uv)->v != &gco2uv(uv)->u.value); + LUAU_ASSERT(uv->tt == LUA_TUPVAL); + LUAU_ASSERT(uv->v != &uv->u.value); } } @@ -178,6 +182,8 @@ static void validateobj(global_State* g, GCObject* o) static void validatelist(global_State* g, GCObject* o) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + while (o) { validateobj(g, o); @@ -216,6 +222,17 @@ static void validategraylist(global_State* g, GCObject* o) } } +static bool validategco(void* context, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + lua_State* L = (lua_State*)context; + global_State* g = L->global; + + validateobj(g, gco); + return false; +} + void luaC_validate(lua_State* L) { global_State* g = L->global; @@ -231,11 +248,18 @@ void luaC_validate(lua_State* L) validategraylist(g, g->gray); validategraylist(g, g->grayagain); - for (int i = 0; i < g->strt.size; ++i) - validatelist(g, g->strt.hash[i]); + if (FFlag::LuauGcPagedSweep) + { + luaM_visitgco(L, L, validategco); + } + else + { + for (int i = 0; i < g->strt.size; ++i) + validatelist(g, (GCObject*)(g->strt.hash[i])); - validatelist(g, g->rootgc); - validatelist(g, g->strbufgc); + validatelist(g, g->rootgc); + validatelist(g, (GCObject*)(g->strbufgc)); + } for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { @@ -499,6 +523,8 @@ static void dumpobj(FILE* f, GCObject* o) static void dumplist(FILE* f, GCObject* o) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + while (o) { dumpref(f, o); @@ -509,22 +535,45 @@ static void dumplist(FILE* f, GCObject* o) // thread has additional list containing collectable objects that are not present in rootgc if (o->gch.tt == LUA_TTHREAD) - dumplist(f, gco2th(o)->openupval); + dumplist(f, (GCObject*)gco2th(o)->openupval); o = o->gch.next; } } +static bool dumpgco(void* context, lua_Page* page, GCObject* gco) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + FILE* f = (FILE*)context; + + dumpref(f, gco); + fputc(':', f); + dumpobj(f, gco); + fputc(',', f); + fputc('\n', f); + + return false; +} + void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)) { global_State* g = L->global; FILE* f = static_cast(file); fprintf(f, "{\"objects\":{\n"); - dumplist(f, g->rootgc); - dumplist(f, g->strbufgc); - for (int i = 0; i < g->strt.size; ++i) - dumplist(f, g->strt.hash[i]); + + if (FFlag::LuauGcPagedSweep) + { + luaM_visitgco(L, f, dumpgco); + } + else + { + dumplist(f, g->rootgc); + dumplist(f, (GCObject*)(g->strbufgc)); + for (int i = 0; i < g->strt.size; ++i) + dumplist(f, (GCObject*)(g->strt.hash[i])); + } fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , fprintf(f, "},\"roots\":{\n"); diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 9f9d4a98f..6d3b77724 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauGcPagedSweep) + #ifndef __has_feature #define __has_feature(x) 0 #endif @@ -42,13 +44,21 @@ static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table #endif static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); +// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(16, 16, 16) static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); +// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(48, 32, 32) static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); +// TODO (FFlagLuauGcPagedSweep): new code with old 'next' pointer requires that GCObject start at the same point as TString/UpVal +static_assert(offsetof(GCObject, uv) == 0, "UpVal data must be located at the start of the GCObject"); +static_assert(offsetof(GCObject, ts) == 0, "TString data must be located at the start of the GCObject"); + const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms +// TODO (FFlagLuauGcPagedSweep): when 'next' is removed, 'kBlockHeader' can be used unconditionally +const size_t kGCOHeader = sizeof(GCheader) > kBlockHeader ? sizeof(GCheader) : kBlockHeader; struct SizeClassConfig { @@ -96,6 +106,7 @@ const SizeClassConfig kSizeClassConfig; // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) +#define freegcolink(block) (*(void**)((char*)block + kGCOHeader)) /* ** About the realloc function: @@ -117,15 +128,22 @@ const SizeClassConfig kSizeClassConfig; struct lua_Page { + // list of pages with free blocks lua_Page* prev; lua_Page* next; + // list of all gco pages + lua_Page* gcolistprev; + lua_Page* gcolistnext; + int busyBlocks; int blockSize; void* freeList; int freeNext; + int pageSize; + union { char data[1]; @@ -141,6 +159,8 @@ l_noret luaM_toobig(lua_State* L) static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + global_State* g = L->global; lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, kPageSize); if (!page) @@ -155,6 +175,9 @@ static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) page->prev = NULL; page->next = NULL; + page->gcolistprev = NULL; + page->gcolistnext = NULL; + page->busyBlocks = 0; page->blockSize = blockSize; @@ -171,8 +194,69 @@ static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) return page; } +static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int blockSize, int blockCount) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + LUAU_ASSERT(pageSize - offsetof(lua_Page, data) >= blockSize * blockCount); + + lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); + if (!page) + luaD_throw(L, LUA_ERRMEM); + + ASAN_POISON_MEMORY_REGION(page->data, blockSize * blockCount); + + // setup page header + page->prev = NULL; + page->next = NULL; + + page->gcolistprev = NULL; + page->gcolistnext = NULL; + + page->busyBlocks = 0; + page->blockSize = blockSize; + + // note: we start with the last block in the page and move downward + // either order would work, but that way we don't need to store the block count in the page + // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order + page->freeList = NULL; + page->freeNext = (blockCount - 1) * blockSize; + + page->pageSize = pageSize; + + if (gcopageset) + { + page->gcolistnext = *gcopageset; + if (page->gcolistnext) + page->gcolistnext->gcolistprev = page; + *gcopageset = page; + } + + return page; +} + +static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); + int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; + + lua_Page* page = newpage(L, gcopageset, kPageSize, blockSize, blockCount); + + // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) + LUAU_ASSERT(!freepageset[sizeClass]); + freepageset[sizeClass] = page; + + return page; +} + static void luaM_freepage(lua_State* L, lua_Page* page, uint8_t sizeClass) { + LUAU_ASSERT(!FFlag::LuauGcPagedSweep); + global_State* g = L->global; // remove page from freelist @@ -188,6 +272,44 @@ static void luaM_freepage(lua_State* L, lua_Page* page, uint8_t sizeClass) (*g->frealloc)(L, g->ud, page, kPageSize, 0); } +static void freepage(lua_State* L, lua_Page** gcopageset, lua_Page* page) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + if (gcopageset) + { + // remove page from alllist + if (page->gcolistnext) + page->gcolistnext->gcolistprev = page->gcolistprev; + + if (page->gcolistprev) + page->gcolistprev->gcolistnext = page->gcolistnext; + else if (*gcopageset == page) + *gcopageset = page->gcolistnext; + } + + // so long + (*g->frealloc)(L, g->ud, page, page->pageSize, 0); +} + +static void freeclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, lua_Page* page, uint8_t sizeClass) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + // remove page from freelist + if (page->next) + page->next->prev = page->prev; + + if (page->prev) + page->prev->next = page->next; + else if (freepageset[sizeClass] == page) + freepageset[sizeClass] = page->next; + + freepage(L, gcopageset, page); +} + static void* luaM_newblock(lua_State* L, int sizeClass) { global_State* g = L->global; @@ -195,7 +317,12 @@ static void* luaM_newblock(lua_State* L, int sizeClass) // slow path: no page in the freelist, allocate a new one if (!page) - page = luaM_newpage(L, sizeClass); + { + if (FFlag::LuauGcPagedSweep) + page = newclasspage(L, g->freepages, NULL, sizeClass, true); + else + page = luaM_newpage(L, sizeClass); + } LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); @@ -236,6 +363,55 @@ static void* luaM_newblock(lua_State* L, int sizeClass) return (char*)block + kBlockHeader; } +static void* luaM_newgcoblock(lua_State* L, int sizeClass) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + lua_Page* page = g->freegcopages[sizeClass]; + + // slow path: no page in the freelist, allocate a new one + if (!page) + page = newclasspage(L, g->freegcopages, &g->allgcopages, sizeClass, false); + + LUAU_ASSERT(!page->prev); + LUAU_ASSERT(page->freeList || page->freeNext >= 0); + LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass]); + + void* block; + + if (page->freeNext >= 0) + { + block = &page->data + page->freeNext; + ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); + + page->freeNext -= page->blockSize; + page->busyBlocks++; + } + else + { + // when separate block metadata is not used, free list link is stored inside the block data itself + block = (char*)page->freeList - kGCOHeader; + + ASAN_UNPOISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); + + page->freeList = freegcolink(block); + page->busyBlocks++; + } + + // if we allocate the last block out of a page, we need to remove it from free list + if (!page->freeList && page->freeNext < 0) + { + g->freegcopages[sizeClass] = page->next; + if (page->next) + page->next->prev = NULL; + page->next = NULL; + } + + // the user data is right after the metadata + return (char*)block; +} + static void luaM_freeblock(lua_State* L, int sizeClass, void* block) { global_State* g = L->global; @@ -270,12 +446,45 @@ static void luaM_freeblock(lua_State* L, int sizeClass, void* block) // if it's the last block in the page, we don't need the page if (page->busyBlocks == 0) - luaM_freepage(L, page, sizeClass); + { + if (FFlag::LuauGcPagedSweep) + freeclasspage(L, g->freepages, NULL, page, sizeClass); + else + luaM_freepage(L, page, sizeClass); + } +} + +static void luaM_freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + // if the page wasn't in the page free list, it should be now since it got a block! + if (!page->freeList && page->freeNext < 0) + { + LUAU_ASSERT(!page->prev); + LUAU_ASSERT(!page->next); + + page->next = g->freegcopages[sizeClass]; + if (page->next) + page->next->prev = page; + g->freegcopages[sizeClass] = page; + } + + // when separate block metadata is not used, free list link is stored inside the block data itself + freegcolink(block) = page->freeList; + page->freeList = (char*)block + kGCOHeader; + + ASAN_POISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); + + page->busyBlocks--; + + // if it's the last block in the page, we don't need the page + if (page->busyBlocks == 0) + freeclasspage(L, g->freegcopages, &g->allgcopages, page, sizeClass); } -/* -** generic allocation routines. -*/ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) { global_State* g = L->global; @@ -292,6 +501,43 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) return block; } +GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) +{ + if (!FFlag::LuauGcPagedSweep) + return (GCObject*)luaM_new_(L, nsize, memcat); + + global_State* g = L->global; + + int nclass = sizeclass(nsize); + + void* block = NULL; + + if (nclass >= 0) + { + LUAU_ASSERT(nsize > 8); + + block = luaM_newgcoblock(L, nclass); + } + else + { + lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + nsize, nsize, 1); + + block = &page->data; + ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); + + page->freeNext -= page->blockSize; + page->busyBlocks++; + } + + if (block == NULL && nsize > 0) + luaD_throw(L, LUA_ERRMEM); + + g->totalbytes += nsize; + g->memcatbytes[memcat] += nsize; + + return (GCObject*)block; +} + void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) { global_State* g = L->global; @@ -308,6 +554,36 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) g->memcatbytes[memcat] -= osize; } +void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, lua_Page* page) +{ + if (!FFlag::LuauGcPagedSweep) + { + luaM_free_(L, block, osize, memcat); + return; + } + + global_State* g = L->global; + LUAU_ASSERT((osize == 0) == (block == NULL)); + + int oclass = sizeclass(osize); + + if (oclass >= 0) + { + block->gch.tt = LUA_TNIL; + + luaM_freegcoblock(L, oclass, block, page); + } + else + { + LUAU_ASSERT(page->busyBlocks == 1); + + freepage(L, &g->allgcopages, page); + } + + g->totalbytes -= osize; + g->memcatbytes[memcat] -= osize; +} + void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8_t memcat) { global_State* g = L->global; @@ -344,3 +620,64 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 g->memcatbytes[memcat] += nsize - osize; return result; } + +void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlocks, int* blockSize) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + int blockCount = (page->pageSize - offsetof(lua_Page, data)) / page->blockSize; + + *start = page->data + page->freeNext + page->blockSize; + *end = page->data + blockCount * page->blockSize; + *busyBlocks = page->busyBlocks; + *blockSize = page->blockSize; +} + +lua_Page* luaM_getnextgcopage(lua_Page* page) +{ + return page->gcolistnext; +} + +void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + char* start; + char* end; + int busyBlocks; + int blockSize; + luaM_getpagewalkinfo(page, &start, &end, &busyBlocks, &blockSize); + + for (char* pos = start; pos != end; pos += blockSize) + { + GCObject* gco = (GCObject*)pos; + + // skip memory blocks that are already freed + if (gco->gch.tt == LUA_TNIL) + continue; + + // when true is returned it means that the element was deleted + if (visitor(context, page, gco)) + { + // if the last block was removed, page would be removed as well + if (--busyBlocks == 0) + break; + } + } +} + +void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) +{ + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + for (lua_Page* curr = g->allgcopages; curr;) + { + lua_Page* next = curr->gcolistnext; // page blockvisit might destroy the page + + luaM_visitpage(curr, context, visitor); + + curr = next; + } +} diff --git a/VM/src/lmem.h b/VM/src/lmem.h index f526a1b67..1bfe48fa8 100644 --- a/VM/src/lmem.h +++ b/VM/src/lmem.h @@ -4,8 +4,15 @@ #include "lua.h" +struct lua_Page; +union GCObject; + +// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_newgco to luaM_new #define luaM_new(L, t, size, memcat) cast_to(t*, luaM_new_(L, size, memcat)) +#define luaM_newgco(L, t, size, memcat) cast_to(t*, luaM_newgco_(L, size, memcat)) +// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_freegco to luaM_free #define luaM_free(L, p, size, memcat) luaM_free_(L, (p), size, memcat) +#define luaM_freegco(L, p, size, memcat, page) luaM_freegco_(L, obj2gco(p), size, memcat, page) #define luaM_arraysize_(n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) @@ -15,7 +22,15 @@ ((v) = cast_to(t*, luaM_realloc_(L, v, (oldn) * sizeof(t), luaM_arraysize_(n, sizeof(t)), memcat))) LUAI_FUNC void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat); +LUAI_FUNC GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat); LUAI_FUNC void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat); +LUAI_FUNC void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, lua_Page* page); LUAI_FUNC void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8_t memcat); LUAI_FUNC l_noret luaM_toobig(lua_State* L); + +LUAI_FUNC void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlocks, int* blockSize); +LUAI_FUNC lua_Page* luaM_getnextgcopage(lua_Page* page); + +LUAI_FUNC void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)); +LUAI_FUNC void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)); diff --git a/VM/src/lobject.h b/VM/src/lobject.h index b642cf787..57ffd82ab 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -11,12 +11,11 @@ typedef union GCObject GCObject; /* -** Common Header for all collectible objects (in macro form, to be -** included in other objects) +** Common Header for all collectible objects (in macro form, to be included in other objects) */ // clang-format off #define CommonHeader \ - GCObject* next; \ + GCObject* next; /* TODO: remove with FFlagLuauGcPagedSweep */ \ uint8_t tt; uint8_t marked; uint8_t memcat // clang-format on @@ -229,8 +228,10 @@ typedef TValue* StkId; /* index to stack elements */ typedef struct TString { CommonHeader; + // 1 byte padding int16_t atom; + // 2 byte padding unsigned int hash; unsigned int len; @@ -314,14 +315,21 @@ typedef struct LocVar typedef struct UpVal { CommonHeader; + // 1 (x86) or 5 (x64) byte padding TValue* v; /* points to stack or to its own value */ union { TValue value; /* the value (when closed) */ struct - { /* double linked list (when open) */ + { + /* global double linked list (when open) */ struct UpVal* prev; struct UpVal* next; + + /* thread double linked list (when open) */ + // TODO: when FFlagLuauGcPagedSweep is removed, old outer 'next' value will be placed here + /* note: this is the location of a pointer to this upvalue in the previous element that can be either an UpVal or a lua_State */ + struct UpVal** threadprev; } l; } u; } UpVal; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 24e970635..6762c6380 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,6 +10,8 @@ #include "ldo.h" #include "ldebug.h" +LUAU_FASTFLAG(LuauGcPagedSweep) + /* ** Main thread combines a thread state and the global state */ @@ -86,14 +88,21 @@ static void close_state(lua_State* L) global_State* g = L->global; luaF_close(L, L->stack); /* close all upvalues for this thread */ luaC_freeall(L); /* collect all objects */ - LUAU_ASSERT(g->rootgc == obj2gco(L)); + if (!FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->rootgc == obj2gco(L)); LUAU_ASSERT(g->strbufgc == NULL); LUAU_ASSERT(g->strt.nuse == 0); luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); freestack(L, L); - LUAU_ASSERT(g->totalbytes == sizeof(LG)); for (int i = 0; i < LUA_SIZECLASSES; i++) + { LUAU_ASSERT(g->freepages[i] == NULL); + if (FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->freegcopages[i] == NULL); + } + if (FFlag::LuauGcPagedSweep) + LUAU_ASSERT(g->allgcopages == NULL); + LUAU_ASSERT(g->totalbytes == sizeof(LG)); LUAU_ASSERT(g->memcatbytes[0] == sizeof(LG)); for (int i = 1; i < LUA_MEMORY_CATEGORIES; i++) LUAU_ASSERT(g->memcatbytes[i] == 0); @@ -102,7 +111,7 @@ static void close_state(lua_State* L) lua_State* luaE_newthread(lua_State* L) { - lua_State* L1 = luaM_new(L, lua_State, sizeof(lua_State), L->activememcat); + lua_State* L1 = luaM_newgco(L, lua_State, sizeof(lua_State), L->activememcat); luaC_link(L, L1, LUA_TTHREAD); preinit_state(L1, L->global); L1->activememcat = L->activememcat; // inherit the active memory category @@ -113,7 +122,7 @@ lua_State* luaE_newthread(lua_State* L) return L1; } -void luaE_freethread(lua_State* L, lua_State* L1) +void luaE_freethread(lua_State* L, lua_State* L1, lua_Page* page) { luaF_close(L1, L1->stack); /* close all upvalues for this thread */ LUAU_ASSERT(L1->openupval == NULL); @@ -121,7 +130,7 @@ void luaE_freethread(lua_State* L, lua_State* L1) if (g->cb.userthread) g->cb.userthread(NULL, L1); freestack(L, L1); - luaM_free(L, L1, sizeof(lua_State), L1->memcat); + luaM_freegco(L, L1, sizeof(lua_State), L1->memcat, page); } void lua_resetthread(lua_State* L) @@ -162,7 +171,8 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) return NULL; L = (lua_State*)l; g = &((LG*)L)->g; - L->next = NULL; + if (!FFlag::LuauGcPagedSweep) + L->next = NULL; L->tt = LUA_TTHREAD; L->marked = g->currentwhite = bit2mask(WHITE0BIT, FIXEDBIT); L->memcat = 0; @@ -185,9 +195,11 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->strt.hash = NULL; setnilvalue(registry(L)); g->gcstate = GCSpause; - g->rootgc = obj2gco(L); + if (!FFlag::LuauGcPagedSweep) + g->rootgc = obj2gco(L); g->sweepstrgc = 0; - g->sweepgc = &g->rootgc; + if (!FFlag::LuauGcPagedSweep) + g->sweepgc = &g->rootgc; g->gray = NULL; g->grayagain = NULL; g->weak = NULL; @@ -197,7 +209,16 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->gcstepmul = LUAI_GCSTEPMUL; g->gcstepsize = LUAI_GCSTEPSIZE << 10; for (i = 0; i < LUA_SIZECLASSES; i++) + { g->freepages[i] = NULL; + if (FFlag::LuauGcPagedSweep) + g->freegcopages[i] = NULL; + } + if (FFlag::LuauGcPagedSweep) + { + g->allgcopages = NULL; + g->sweepgcopage = NULL; + } for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 563798833..080f00248 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -22,7 +22,7 @@ typedef struct stringtable { - GCObject** hash; + TString** hash; uint32_t nuse; /* number of elements */ int size; } stringtable; @@ -149,13 +149,15 @@ typedef struct global_State int sweepstrgc; /* position of sweep in `strt' */ + // TODO: remove with FFlagLuauGcPagedSweep GCObject* rootgc; /* list of all collectable objects */ + // TODO: remove with FFlagLuauGcPagedSweep GCObject** sweepgc; /* position of sweep in `rootgc' */ GCObject* gray; /* list of gray objects */ GCObject* grayagain; /* list of objects to be traversed atomically */ GCObject* weak; /* list of weak tables (to be cleared) */ - GCObject* strbufgc; // list of all string buffer objects + TString* strbufgc; // list of all string buffer objects size_t GCthreshold; // when totalbytes > GCthreshold; run GC step @@ -164,7 +166,10 @@ typedef struct global_State int gcstepmul; // see LUAI_GCSTEPMUL int gcstepsize; // see LUAI_GCSTEPSIZE - struct lua_Page* freepages[LUA_SIZECLASSES]; /* free page linked list for each size class */ + struct lua_Page* freepages[LUA_SIZECLASSES]; // free page linked list for each size class for non-collectable objects + struct lua_Page* freegcopages[LUA_SIZECLASSES]; // free page linked list for each size class for collectable objects + struct lua_Page* allgcopages; // page linked list with all pages for all classes + struct lua_Page* sweepgcopage; // position of the sweep in `allgcopages' size_t memcatbytes[LUA_MEMORY_CATEGORIES]; /* total amount of memory used by each memory category */ @@ -231,7 +236,7 @@ struct lua_State TValue l_gt; /* table of globals */ TValue env; /* temporary place for environments */ - GCObject* openupval; /* list of open upvalues in this stack */ + UpVal* openupval; /* list of open upvalues in this stack */ GCObject* gclist; TString* namecall; /* when invoked from Luau using NAMECALL, what method do we need to invoke? */ @@ -268,4 +273,4 @@ union GCObject #define obj2gco(v) check_exp(iscollectable(v), cast_to(GCObject*, (v) + 0)) LUAI_FUNC lua_State* luaE_newthread(lua_State* L); -LUAI_FUNC void luaE_freethread(lua_State* L, lua_State* L1); +LUAI_FUNC void luaE_freethread(lua_State* L, lua_State* L1, struct lua_Page* page); diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index a9e90d17a..cb22cc23a 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAG(LuauGcPagedSweep) + unsigned int luaS_hash(const char* str, size_t len) { // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash @@ -44,26 +46,25 @@ unsigned int luaS_hash(const char* str, size_t len) void luaS_resize(lua_State* L, int newsize) { - GCObject** newhash; - stringtable* tb; - int i; if (L->global->gcstate == GCSsweepstring) return; /* cannot resize during GC traverse */ - newhash = luaM_newarray(L, newsize, GCObject*, 0); - tb = &L->global->strt; - for (i = 0; i < newsize; i++) + TString** newhash = luaM_newarray(L, newsize, TString*, 0); + stringtable* tb = &L->global->strt; + for (int i = 0; i < newsize; i++) newhash[i] = NULL; /* rehash */ - for (i = 0; i < tb->size; i++) + for (int i = 0; i < tb->size; i++) { - GCObject* p = tb->hash[i]; + TString* p = tb->hash[i]; while (p) { /* for each node in the list */ - GCObject* next = p->gch.next; /* save next */ - unsigned int h = gco2ts(p)->hash; + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + TString* next = (TString*)p->next; /* save next */ + unsigned int h = p->hash; int h1 = lmod(h, newsize); /* new position */ LUAU_ASSERT(cast_int(h % newsize) == lmod(h, newsize)); - p->gch.next = newhash[h1]; /* chain it */ + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + p->next = (GCObject*)newhash[h1]; /* chain it */ newhash[h1] = p; p = next; } @@ -79,7 +80,7 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) stringtable* tb; if (l > MAXSSIZE) luaM_toobig(L); - ts = luaM_new(L, TString, sizestring(l), L->activememcat); + ts = luaM_newgco(L, TString, sizestring(l), L->activememcat); ts->len = unsigned(l); ts->hash = h; ts->marked = luaC_white(L->global); @@ -90,8 +91,9 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1; tb = &L->global->strt; h = lmod(h, tb->size); - ts->next = tb->hash[h]; /* chain new entry */ - tb->hash[h] = obj2gco(ts); + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the case will not be required + ts->next = (GCObject*)tb->hash[h]; /* chain new entry */ + tb->hash[h] = ts; tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) luaS_resize(L, tb->size * 2); /* too crowded */ @@ -101,28 +103,41 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) static void linkstrbuf(lua_State* L, TString* ts) { global_State* g = L->global; - GCObject* o = obj2gco(ts); - o->gch.next = g->strbufgc; - g->strbufgc = o; - o->gch.marked = luaC_white(g); + + if (FFlag::LuauGcPagedSweep) + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + ts->next = (GCObject*)g->strbufgc; + g->strbufgc = ts; + ts->marked = luaC_white(g); + } + else + { + GCObject* o = obj2gco(ts); + o->gch.next = (GCObject*)g->strbufgc; + g->strbufgc = gco2ts(o); + o->gch.marked = luaC_white(g); + } } static void unlinkstrbuf(lua_State* L, TString* ts) { global_State* g = L->global; - GCObject** p = &g->strbufgc; + TString** p = &g->strbufgc; - while (GCObject* curr = *p) + while (TString* curr = *p) { - if (curr == obj2gco(ts)) + if (curr == ts) { - *p = curr->gch.next; + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + *p = (TString*)curr->next; return; } else { - p = &curr->gch.next; + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + p = (TString**)&curr->next; } } @@ -134,7 +149,7 @@ TString* luaS_bufstart(lua_State* L, size_t size) if (size > MAXSSIZE) luaM_toobig(L); - TString* ts = luaM_new(L, TString, sizestring(size), L->activememcat); + TString* ts = luaM_newgco(L, TString, sizestring(size), L->activememcat); ts->tt = LUA_TSTRING; ts->memcat = L->activememcat; @@ -152,15 +167,14 @@ TString* luaS_buffinish(lua_State* L, TString* ts) int bucket = lmod(h, tb->size); // search if we already have this string in the hash table - for (GCObject* o = tb->hash[bucket]; o != NULL; o = o->gch.next) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (TString* el = tb->hash[bucket]; el != NULL; el = (TString*)el->next) { - TString* el = gco2ts(o); - if (el->len == ts->len && memcmp(el->data, ts->data, ts->len) == 0) { // string may be dead - if (isdead(L->global, o)) - changewhite(o); + if (isdead(L->global, obj2gco(el))) + changewhite(obj2gco(el)); return el; } @@ -173,8 +187,9 @@ TString* luaS_buffinish(lua_State* L, TString* ts) // Complete string object ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; - ts->next = tb->hash[bucket]; // chain new entry - tb->hash[bucket] = obj2gco(ts); + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + ts->next = (GCObject*)tb->hash[bucket]; // chain new entry + tb->hash[bucket] = ts; tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) @@ -185,24 +200,63 @@ TString* luaS_buffinish(lua_State* L, TString* ts) TString* luaS_newlstr(lua_State* L, const char* str, size_t l) { - GCObject* o; unsigned int h = luaS_hash(str, l); - for (o = L->global->strt.hash[lmod(h, L->global->strt.size)]; o != NULL; o = o->gch.next) + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + for (TString* el = L->global->strt.hash[lmod(h, L->global->strt.size)]; el != NULL; el = (TString*)el->next) { - TString* ts = gco2ts(o); - if (ts->len == l && (memcmp(str, getstr(ts), l) == 0)) + if (el->len == l && (memcmp(str, getstr(el), l) == 0)) { /* string may be dead */ - if (isdead(L->global, o)) - changewhite(o); - return ts; + if (isdead(L->global, obj2gco(el))) + changewhite(obj2gco(el)); + return el; } } return newlstr(L, str, l, h); /* not found */ } -void luaS_free(lua_State* L, TString* ts) +static bool unlinkstr(lua_State* L, TString* ts) { - L->global->strt.nuse--; - luaM_free(L, ts, sizestring(ts->len), ts->memcat); + LUAU_ASSERT(FFlag::LuauGcPagedSweep); + + global_State* g = L->global; + + TString** p = &g->strt.hash[lmod(ts->hash, g->strt.size)]; + + while (TString* curr = *p) + { + if (curr == ts) + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + *p = (TString*)curr->next; + return true; + } + else + { + // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required + p = (TString**)&curr->next; + } + } + + return false; +} + +void luaS_free(lua_State* L, TString* ts, lua_Page* page) +{ + if (FFlag::LuauGcPagedSweep) + { + // Unchain from the string table + if (!unlinkstr(L, ts)) + unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands + else + L->global->strt.nuse--; + + luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); + } + else + { + L->global->strt.nuse--; + + luaM_free(L, ts, sizestring(ts->len), ts->memcat); + } } diff --git a/VM/src/lstring.h b/VM/src/lstring.h index 3fd0bd39b..290b64d87 100644 --- a/VM/src/lstring.h +++ b/VM/src/lstring.h @@ -20,7 +20,7 @@ LUAI_FUNC unsigned int luaS_hash(const char* str, size_t len); LUAI_FUNC void luaS_resize(lua_State* L, int newsize); LUAI_FUNC TString* luaS_newlstr(lua_State* L, const char* str, size_t l); -LUAI_FUNC void luaS_free(lua_State* L, TString* ts); +LUAI_FUNC void luaS_free(lua_State* L, TString* ts, struct lua_Page* page); LUAI_FUNC TString* luaS_bufstart(lua_State* L, size_t size); LUAI_FUNC TString* luaS_buffinish(lua_State* L, TString* ts); diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 83b59f3fa..c57374e0e 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -424,7 +424,7 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) Table* luaH_new(lua_State* L, int narray, int nhash) { - Table* t = luaM_new(L, Table, sizeof(Table), L->activememcat); + Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_link(L, t, LUA_TTABLE); t->metatable = NULL; t->flags = cast_byte(~0); @@ -443,12 +443,12 @@ Table* luaH_new(lua_State* L, int narray, int nhash) return t; } -void luaH_free(lua_State* L, Table* t) +void luaH_free(lua_State* L, Table* t, lua_Page* page) { if (t->node != dummynode) luaM_freearray(L, t->node, sizenode(t), LuaNode, t->memcat); luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); - luaM_free(L, t, sizeof(Table), t->memcat); + luaM_freegco(L, t, sizeof(Table), t->memcat, page); } static LuaNode* getfreepos(Table* t) @@ -741,7 +741,7 @@ int luaH_getn(Table* t) Table* luaH_clone(lua_State* L, Table* tt) { - Table* t = luaM_new(L, Table, sizeof(Table), L->activememcat); + Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); luaC_link(L, t, LUA_TTABLE); t->metatable = tt->metatable; t->flags = tt->flags; diff --git a/VM/src/ltable.h b/VM/src/ltable.h index 45061443e..e8413c853 100644 --- a/VM/src/ltable.h +++ b/VM/src/ltable.h @@ -20,7 +20,7 @@ LUAI_FUNC TValue* luaH_set(lua_State* L, Table* t, const TValue* key); LUAI_FUNC Table* luaH_new(lua_State* L, int narray, int lnhash); LUAI_FUNC void luaH_resizearray(lua_State* L, Table* t, int nasize); LUAI_FUNC void luaH_resizehash(lua_State* L, Table* t, int nhsize); -LUAI_FUNC void luaH_free(lua_State* L, Table* t); +LUAI_FUNC void luaH_free(lua_State* L, Table* t, struct lua_Page* page); LUAI_FUNC int luaH_next(lua_State* L, Table* t, StkId key); LUAI_FUNC int luaH_getn(Table* t); LUAI_FUNC Table* luaH_clone(lua_State* L, Table* tt); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index d180c388e..758a9bdb7 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -11,7 +11,7 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) { if (s > INT_MAX - sizeof(Udata)) luaM_toobig(L); - Udata* u = luaM_new(L, Udata, sizeudata(s), L->activememcat); + Udata* u = luaM_newgco(L, Udata, sizeudata(s), L->activememcat); luaC_link(L, u, LUA_TUSERDATA); u->len = int(s); u->metatable = NULL; @@ -20,7 +20,7 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) return u; } -void luaU_freeudata(lua_State* L, Udata* u) +void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); @@ -33,5 +33,5 @@ void luaU_freeudata(lua_State* L, Udata* u) if (dtor) dtor(u->data); - luaM_free(L, u, sizeudata(u->len), u->memcat); + luaM_freegco(L, u, sizeudata(u->len), u->memcat, page); } diff --git a/VM/src/ludata.h b/VM/src/ludata.h index 59cb85bd1..ec374c28b 100644 --- a/VM/src/ludata.h +++ b/VM/src/ludata.h @@ -10,4 +10,4 @@ #define sizeudata(len) (offsetof(Udata, data) + len) LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); -LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u); +LUAI_FUNC void luaU_freeudata(lua_State* L, Udata* u, struct lua_Page* page); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index cebeeb584..e58ff2a8e 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -496,7 +496,7 @@ static void luau_execute(lua_State* L) Instruction insn = *pc++; StkId ra = VM_REG(LUAU_INSN_A(insn)); - if (L->openupval && gco2uv(L->openupval)->v >= ra) + if (L->openupval && L->openupval->v >= ra) luaF_close(L, ra); VM_NEXT(); } diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 8eed953fd..3b0d677de 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -756,6 +756,30 @@ RETURN R0 1 )"); } +TEST_CASE("TableSizePredictionLoop") +{ + ScopedFastFlag sff("LuauPredictTableSizeLoop", true); + + CHECK_EQ("\n" + compileFunction0(R"( +local t = {} +for i=1,4 do + t[i] = 0 +end +return t +)"), + R"( +NEWTABLE R0 0 4 +LOADN R3 1 +LOADN R1 4 +LOADN R2 1 +FORNPREP R1 +3 +LOADN R4 0 +SETTABLE R4 R0 R3 +FORNLOOP R1 -3 +RETURN R0 1 +)"); +} + TEST_CASE("ReflectionEnums") { CHECK_EQ("\n" + compileFunction0("return Enum.EasingStyle.Linear"), R"( diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 1d13df289..5ad06f0d7 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1396,6 +1396,8 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { + ScopedFastFlag sff("LuauLintTableCreateTable", true); + LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1416,9 +1418,12 @@ table.insert(t, string.find("hello", "h")) table.move(t, 0, #t, 1, tt) table.move(t, 1, #t, 0, tt) + +table.create(42, {}) +table.create(42, {} :: {}) )"); - REQUIRE_EQ(result.warnings.size(), 8); + REQUIRE_EQ(result.warnings.size(), 10); CHECK_EQ(result.warnings[0].text, "table.insert will insert the value before the last element, which is likely a bug; consider removing the " "second argument or wrap it in parentheses to silence"); CHECK_EQ(result.warnings[1].text, "table.insert will append the value to the table; consider removing the second argument for efficiency"); @@ -1430,6 +1435,8 @@ table.move(t, 1, #t, 0, tt) "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); + CHECK_EQ(result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + CHECK_EQ(result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index e91356515..ac81005c5 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1498,6 +1498,17 @@ return CHECK_EQ(std::string(str->value.data, str->value.size), "\n"); } +TEST_CASE_FIXTURE(Fixture, "parse_error_broken_comment") +{ + ScopedFastFlag luauStartingBrokenComment{"LuauStartingBrokenComment", true}; + + const char* expected = "Expected identifier when parsing expression, got unfinished comment"; + + matchParseError("--[[unfinished work", expected); + matchParseError("--!strict\n--[[unfinished work", expected); + matchParseError("local x = 1 --[[unfinished work", expected); +} + TEST_CASE_FIXTURE(Fixture, "string_literals_escapes_broken") { const char* expected = "String literal contains malformed escape sequence"; @@ -2333,7 +2344,7 @@ TEST_CASE_FIXTURE(Fixture, "capture_broken_comment_at_the_start_of_the_file") ParseOptions options; options.captureComments = true; - ParseResult result = parseEx(R"( + ParseResult result = tryParse(R"( --[[ )", options); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 27cda1463..644efed75 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2180,4 +2180,52 @@ b() CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } +TEST_CASE_FIXTURE(Fixture, "length_operator_union") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} | {string} +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_intersection") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} & {z:string} -- mixed tables are evil +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_non_table_union") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} | any | string +local y = #x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "length_operator_union_errors") +{ + ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; + + CheckResult result = check(R"( +local x: {number} | number | string +local y = #x + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7a056af52..7ee5253c7 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5129,4 +5129,33 @@ local c = a or b LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "bound_typepack_promote") +{ + ScopedFastFlag luauCommittingTxnLogFreeTpPromote{"LuauCommittingTxnLogFreeTpPromote", true}; + + // No assertions should trigger + check(R"( +local function p() + local this = {} + this.pf = foo() + function this:IsActive() end + function this:Start(o) end + return this +end + +local function h(tp, o) + ep = tp + tp:Start(o) + tp.pf.Connect(function() + ep:IsActive() + end) +end + +function on() + local t = p() + h(t) +end + )"); +} + TEST_SUITE_END(); diff --git a/tests/conformance/closure.lua b/tests/conformance/closure.lua index f32d5bdcb..7b0573546 100644 --- a/tests/conformance/closure.lua +++ b/tests/conformance/closure.lua @@ -419,5 +419,20 @@ co = coroutine.create(function () return loadstring("return a")() end) +-- large closure size +do + local a1, a2, a3, a4, a5, a6, a7, a8, a9, a0 + local b1, b2, b3, b4, b5, b6, b7, b8, b9, b0 + local c1, c2, c3, c4, c5, c6, c7, c8, c9, c0 + local d1, d2, d3, d4, d5, d6, d7, d8, d9, d0 + + local f = function() + return + a1 + a2 + a3 + a4 + a5 + a6 + a7 + a8 + a9 + a0 + + b1 + b2 + b3 + b4 + b5 + b6 + b7 + b8 + b9 + b0 + + c1 + c2 + c3 + c4 + c5 + c6 + c7 + c8 + c9 + c0 + + d1 + d2 + d3 + d4 + d5 + d6 + d7 + d8 + d9 + d0 + end +end return 'OK' diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 6d9eb8544..409cd2247 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -291,4 +291,32 @@ do for i = 1,10 do table.insert(___Glob, newproxy(true)) end end +-- create threads that die together with their unmarked upvalues +do + local t = {} + + for i = 1,100 do + local c = coroutine.wrap(function() + local uv = {i + 1} + local function f() + return uv[1] * 10 + end + coroutine.yield(uv[1]) + uv = {i + 2} + coroutine.yield(f()) + end) + + assert(c() == i + 1) + table.insert(t, c) + end + + for i = 1,100 do + t[i] = nil + end + + collectgarbage() + +end + + return('OK') From 699660a4ebe33c582a71f73caacdc98440228b5d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:37:50 -0800 Subject: [PATCH 16/32] Fix MSVC warnings --- VM/src/lgc.cpp | 4 ++-- VM/src/lmem.cpp | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 50859b1e8..9a8cb0797 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -770,11 +770,11 @@ static int sweepgcopage(lua_State* L, lua_Page* page) { // if the last block was removed, page would be removed as well if (--busyBlocks == 0) - return (pos - start) / blockSize + 1; + return int(pos - start) / blockSize + 1; } } - return (end - start) / blockSize; + return int(end - start) / blockSize; } static size_t gcstep(lua_State* L, size_t limit) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 6d3b77724..7a31d6c81 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -520,7 +520,7 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) } else { - lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + nsize, nsize, 1); + lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + int(nsize), int(nsize), 1); block = &page->data; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); From 0062000d4674c44bb145584b2baa531388c9f355 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:43:41 -0800 Subject: [PATCH 17/32] One more --- VM/src/lmem.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 7a31d6c81..beacca656 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -200,7 +200,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int global_State* g = L->global; - LUAU_ASSERT(pageSize - offsetof(lua_Page, data) >= blockSize * blockCount); + LUAU_ASSERT(pageSize - int(offsetof(lua_Page, data)) >= blockSize * blockCount); lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); if (!page) From 9c15f6a6d79d769f4ac6cf80b1521826035f4d76 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 21 Jan 2022 08:52:48 -0800 Subject: [PATCH 18/32] And one more --- VM/src/lmem.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index beacca656..e1dbce504 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -376,7 +376,7 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); - LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass]); + LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); void* block; From 6e1e277cb8f19f18740259c830ef49920f76043d Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 27 Jan 2022 13:29:34 -0800 Subject: [PATCH 19/32] Sync to upstream/release/512 --- Analysis/include/Luau/AstQuery.h | 15 + Analysis/include/Luau/Module.h | 8 + Analysis/include/Luau/TypeInfer.h | 21 +- Analysis/include/Luau/TypeVar.h | 8 +- Analysis/include/Luau/Unifier.h | 7 + Analysis/src/Autocomplete.cpp | 101 ++--- Analysis/src/Frontend.cpp | 7 +- Analysis/src/Module.cpp | 28 +- Analysis/src/ToString.cpp | 11 +- Analysis/src/TypeAttach.cpp | 2 +- Analysis/src/TypeInfer.cpp | 476 ++++++++++++--------- Analysis/src/TypeVar.cpp | 130 +++--- Analysis/src/Unifier.cpp | 317 ++++++-------- CLI/Analyze.cpp | 40 +- CLI/FileUtils.cpp | 24 +- CLI/FileUtils.h | 1 + CLI/Repl.cpp | 17 +- CLI/Repl.h | 12 + CLI/ReplEntry.cpp | 10 + CMakeLists.txt | 12 + Compiler/include/Luau/Bytecode.h | 3 + Compiler/src/Builtins.cpp | 5 + Compiler/src/Compiler.cpp | 353 ++++++++++------ Makefile | 7 +- Sources.cmake | 18 +- VM/src/lapi.cpp | 13 +- VM/src/lbuiltins.cpp | 30 ++ VM/src/lcorolib.cpp | 5 - VM/src/ldo.cpp | 4 +- VM/src/lgc.cpp | 35 +- VM/src/lgc.h | 1 + VM/src/lmem.cpp | 6 +- VM/src/lstate.h | 3 - tests/AstQuery.test.cpp | 6 - tests/Autocomplete.test.cpp | 36 +- tests/Compiler.test.cpp | 131 ++++++ tests/Conformance.test.cpp | 2 - tests/Frontend.test.cpp | 1 - tests/Parser.test.cpp | 134 +++--- tests/Repl.test.cpp | 117 ++++++ tests/ToString.test.cpp | 4 - tests/TypeInfer.aliases.test.cpp | 4 - tests/TypeInfer.generics.test.cpp | 15 +- tests/TypeInfer.provisional.test.cpp | 28 +- tests/TypeInfer.refinements.test.cpp | 591 ++++++++++++++++----------- tests/TypeInfer.singletons.test.cpp | 6 - tests/TypeInfer.tables.test.cpp | 16 +- tests/TypeInfer.test.cpp | 64 +-- tests/TypeInfer.typePacks.cpp | 1 - tests/TypeInfer.unionTypes.test.cpp | 2 - tests/TypeVar.test.cpp | 72 +++- 51 files changed, 1825 insertions(+), 1135 deletions(-) create mode 100644 CLI/Repl.h create mode 100644 CLI/ReplEntry.cpp create mode 100644 tests/Repl.test.cpp diff --git a/Analysis/include/Luau/AstQuery.h b/Analysis/include/Luau/AstQuery.h index d38976ef7..dfe373a5c 100644 --- a/Analysis/include/Luau/AstQuery.h +++ b/Analysis/include/Luau/AstQuery.h @@ -42,6 +42,21 @@ struct ExprOrLocal { return expr ? expr->location : (local ? local->location : std::optional{}); } + std::optional getName() + { + if (expr) + { + if (AstName name = getIdentifier(expr); name.value) + { + return name; + } + } + else if (local) + { + return local->name; + } + return std::nullopt; + } private: AstExpr* expr = nullptr; diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 2e41674bf..1bf0473c1 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -13,6 +13,8 @@ #include #include +LUAU_FASTFLAG(LuauPrepopulateUnionOptionsBeforeAllocation) + namespace Luau { @@ -58,6 +60,12 @@ struct TypeArena template TypeId addType(T tv) { + if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) + { + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); + } + return addTV(TypeVar(std::move(tv))); } diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index aa0900140..b843509dc 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -135,7 +135,8 @@ struct TypeChecker void checkBlock(const ScopePtr& scope, const AstStatBlock& statement); void checkBlockTypeAliases(const ScopePtr& scope, std::vector& sorted); - ExprResult checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt); + ExprResult checkExpr( + const ScopePtr& scope, const AstExpr& expr, std::optional expectedType = std::nullopt, bool forceSingleton = false); ExprResult checkExpr(const ScopePtr& scope, const AstExprLocal& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprGlobal& expr); ExprResult checkExpr(const ScopePtr& scope, const AstExprVarargs& expr); @@ -160,14 +161,12 @@ struct TypeChecker // Returns the type of the lvalue. TypeId checkLValue(const ScopePtr& scope, const AstExpr& expr); - // Returns both the type of the lvalue and its binding (if the caller wants to mutate the binding). - // Note: the binding may be null. - // TODO: remove second return value with FFlagLuauUpdateFunctionNameBinding - std::pair checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); - std::pair checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); + // Returns the type of the lvalue. + TypeId checkLValueBinding(const ScopePtr& scope, const AstExpr& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr); + TypeId checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr); TypeId checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level); std::pair checkFunctionSignature(const ScopePtr& scope, int subLevel, const AstExprFunction& expr, @@ -322,8 +321,6 @@ struct TypeChecker return addTV(TypeVar(tv)); } - TypeId addType(const UnionTypeVar& utv); - TypeId addTV(TypeVar&& tv); TypePackId addTypePack(TypePackVar&& tp); @@ -349,6 +346,8 @@ struct TypeChecker ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); private: + void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); + std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 3f5e26d66..11dc93773 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -111,16 +111,16 @@ struct PrimitiveTypeVar // Singleton types https://github.com/Roblox/luau/blob/master/rfcs/syntax-singleton-types.md // Types for true and false -struct BoolSingleton +struct BooleanSingleton { bool value; - bool operator==(const BoolSingleton& rhs) const + bool operator==(const BooleanSingleton& rhs) const { return value == rhs.value; } - bool operator!=(const BoolSingleton& rhs) const + bool operator!=(const BooleanSingleton& rhs) const { return !(*this == rhs); } @@ -145,7 +145,7 @@ struct StringSingleton // No type for float singletons, partly because === isn't any equalivalence on floats // (NaN != NaN). -using SingletonVariant = Luau::Variant; +using SingletonVariant = Luau::Variant; struct SingletonTypeVar { diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index a3be739a6..1b1671c0a 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -85,6 +85,13 @@ struct Unifier Unifier makeChildUnifier(); + // A utility function that appends the given error to the unifier's error log. + // This allows setting a breakpoint wherever the unifier reports an error. + void reportError(TypeError error) + { + errors.push_back(error); + } + private: bool isNonstrictMode() const; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 7a801f970..85099e12c 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,9 +14,9 @@ LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauAutocompleteFirstArg, false); LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); +LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -194,8 +194,6 @@ static ParenthesesRecommendation getParenRecommendation(TypeId id, const std::ve static std::optional findExpectedTypeAt(const Module& module, AstNode* node, Position position) { - LUAU_ASSERT(FFlag::LuauAutocompleteFirstArg); - auto expr = node->asExpr(); if (!expr) return std::nullopt; @@ -266,43 +264,63 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } }; - TypeId expectedType; + auto typeAtPosition = findExpectedTypeAt(module, node, position); - if (FFlag::LuauAutocompleteFirstArg) - { - auto typeAtPosition = findExpectedTypeAt(module, node, position); + if (!typeAtPosition) + return TypeCorrectKind::None; - if (!typeAtPosition) - return TypeCorrectKind::None; + TypeId expectedType = follow(*typeAtPosition); - expectedType = follow(*typeAtPosition); - } - else + if (FFlag::PreferToCallFunctionsForIntersects) { - auto expr = node->asExpr(); - if (!expr) - return TypeCorrectKind::None; + auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { + auto [retHead, retTail] = flatten(ftv->retType); - auto it = module.astExpectedTypes.find(expr); - if (!it) - return TypeCorrectKind::None; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return true; - expectedType = follow(*it); - } + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return true; + } - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) - { - auto [retHead, retTail] = flatten(ftv->retType); + return false; + }; - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) + { return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + } + else if (const IntersectionTypeVar* itv = get(ty)) + { + for (TypeId id : itv->parts) + { + if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; + } + } + } + } + else + { + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty)) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return TypeCorrectKind::CorrectFunctionResult; + + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return TypeCorrectKind::CorrectFunctionResult; + } } } @@ -741,29 +759,12 @@ std::optional returnFirstNonnullOptionOfType(const UnionTypeVar* utv) static std::optional functionIsExpectedAt(const Module& module, AstNode* node, Position position) { - TypeId expectedType; + auto typeAtPosition = findExpectedTypeAt(module, node, position); - if (FFlag::LuauAutocompleteFirstArg) - { - auto typeAtPosition = findExpectedTypeAt(module, node, position); - - if (!typeAtPosition) - return std::nullopt; - - expectedType = follow(*typeAtPosition); - } - else - { - auto expr = node->asExpr(); - if (!expr) - return std::nullopt; - - auto it = module.astExpectedTypes.find(expr); - if (!it) - return std::nullopt; + if (!typeAtPosition) + return std::nullopt; - expectedType = follow(*it); - } + TypeId expectedType = follow(*typeAtPosition); if (get(expectedType)) return true; diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index fe4b6529a..9001b19df 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -18,7 +18,6 @@ LUAU_FASTFLAG(LuauInferInNoCheckMode) LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) -LUAU_FASTFLAGVARIABLE(LuauPersistDefinitionFileTypes, false) namespace Luau { @@ -102,8 +101,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t generateDocumentationSymbols(globalTy, documentationSymbol); targetScope->bindings[typeChecker.globalNames.names->getOrAdd(name.c_str())] = {globalTy, Location(), false, {}, documentationSymbol}; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy); + persist(globalTy); } for (const auto& [name, ty] : checkedModule->getModuleScope()->exportedTypeBindings) @@ -113,8 +111,7 @@ LoadDefinitionFileResult loadDefinitionFile(TypeChecker& typeChecker, ScopePtr t generateDocumentationSymbols(globalTy.type, documentationSymbol); targetScope->exportedTypeBindings[name] = globalTy; - if (FFlag::LuauPersistDefinitionFileTypes) - persist(globalTy.type); + persist(globalTy.type); } return LoadDefinitionFileResult{true, parseResult, checkedModule}; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 9f352f4b3..4fdff8f7a 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,6 +16,8 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypeAliasDefaults) +LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false) + namespace Luau { @@ -377,14 +379,28 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - TypeId result = dest.addType(UnionTypeVar{}); - seenTypes[typeId] = result; + if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) + { + std::vector options; + options.reserve(t.options.size()); - UnionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; + } + else + { + TypeId result = dest.addType(UnionTypeVar{}); + seenTypes[typeId] = result; - for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + UnionTypeVar* option = getMutable(result); + LUAU_ASSERT(option != nullptr); + + for (TypeId ty : t.options) + option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); + } } void TypeCloner::operator()(const IntersectionTypeVar& t) diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 4b898d3a6..5e79b8413 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,7 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauOccursCheckOkWithRecursiveFunctions) LUAU_FASTFLAG(LuauTypeAliasDefaults) /* @@ -374,7 +373,7 @@ struct TypeVarStringifier void operator()(TypeId, const SingletonTypeVar& stv) { - if (const BoolSingleton* bs = Luau::get(&stv)) + if (const BooleanSingleton* bs = Luau::get(&stv)) state.emit(bs->value ? "true" : "false"); else if (const StringSingleton* ss = Luau::get(&stv)) { @@ -617,9 +616,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); @@ -675,9 +672,7 @@ struct TypeVarStringifier std::string saved = std::move(state.result.name); - bool needParens = FFlag::LuauOccursCheckOkWithRecursiveFunctions - ? !state.cycleNames.count(el) && (get(el) || get(el)) - : get(el) || get(el); + bool needParens = !state.cycleNames.count(el) && (get(el) || get(el)); if (needParens) state.emit("("); diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 2ec020937..2208213f7 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -97,7 +97,7 @@ class TypeRehydrationVisitor AstType* operator()(const SingletonTypeVar& stv) { - if (const BoolSingleton* bs = get(&stv)) + if (const BooleanSingleton* bs = get(&stv)) return allocator->alloc(Location(), bs->value); else if (const StringSingleton* ss = get(&stv)) { diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e2d8a4fb2..23fcc2d5b 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -26,8 +26,6 @@ LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauCloneCorrectlyBeforeMutatingTableType, false) -LUAU_FASTFLAGVARIABLE(LuauStoreMatchingOverloadFnType, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) @@ -37,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) +LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) @@ -46,10 +45,8 @@ LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) -LUAU_FASTFLAGVARIABLE(LuauFixRecursiveMetatableCall, false) LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) -LUAU_FASTFLAGVARIABLE(LuauUpdateFunctionNameBinding, false) namespace Luau { @@ -1139,33 +1136,25 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - auto [leftType, leftTypeBinding] = checkLValueBinding(scope, *function.name); + TypeId leftType = checkLValueBinding(scope, *function.name); checkFunctionBody(funScope, ty, *function.func); unify(ty, leftType, function.location); - if (FFlag::LuauUpdateFunctionNameBinding) - { - LUAU_ASSERT(function.name->is() || function.name->is()); + LUAU_ASSERT(function.name->is() || function.name->is()); - if (auto exprIndexName = function.name->as()) + if (auto exprIndexName = function.name->as()) + { + if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) { - if (auto typeIt = currentModule->astTypes.find(exprIndexName->expr)) + if (auto ttv = getMutableTableType(*typeIt)) { - if (auto ttv = getMutableTableType(*typeIt)) - { - if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) - it->second.type = follow(quantify(funScope, leftType, function.name->location)); - } + if (auto it = ttv->props.find(exprIndexName->index.value); it != ttv->props.end()) + it->second.type = follow(quantify(funScope, leftType, function.name->location)); } } } - else - { - if (leftTypeBinding) - *leftTypeBinding = follow(quantify(funScope, leftType, function.name->location)); - } } } @@ -1426,7 +1415,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareFunction& glo currentModule->getModuleScope()->bindings[global.name] = Binding{fnType, global.location}; } -ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType) +ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& expr, std::optional expectedType, bool forceSingleton) { RecursionCounter _rc(&checkRecursionCount); if (FInt::LuauCheckRecursionLimit > 0 && checkRecursionCount >= FInt::LuauCheckRecursionLimit) @@ -1443,14 +1432,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) result = {singletonType(bexpr->value)}; else result = {booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && expectedType && maybeSingleton(*expectedType)) + if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else result = {stringType}; @@ -1488,15 +1477,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result.type = follow(result.type); - if (FFlag::LuauStoreMatchingOverloadFnType) - { - if (!currentModule->astTypes.find(&expr)) - currentModule->astTypes[&expr] = result.type; - } - else - { + if (!currentModule->astTypes.find(&expr)) currentModule->astTypes[&expr] = result.type; - } if (expectedType) currentModule->astExpectedTypes[&expr] = *expectedType; @@ -2242,7 +2224,6 @@ TypeId TypeChecker::checkRelationalOperation( state.log.commit(); } - bool needsMetamethod = !isEquality; TypeId leftType = follow(lhsType); @@ -2250,10 +2231,11 @@ TypeId TypeChecker::checkRelationalOperation( { reportErrors(state.errors); - const PrimitiveTypeVar* ptv = get(leftType); - if (!isEquality && state.errors.empty() && (get(leftType) || (ptv && ptv->type == PrimitiveTypeVar::Boolean))) + if (!isEquality && state.errors.empty() && (get(leftType) || isBoolean(leftType))) + { reportError(expr.location, GenericError{format("Type '%s' cannot be compared with relational operator %s", toString(leftType).c_str(), toString(expr.op).c_str())}); + } return booleanType; } @@ -2501,7 +2483,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult rhs = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(innerScope, expr, lhs.type, rhs.type), {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type), + {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } else if (expr.op == AstExprBinary::Or) { @@ -2513,7 +2496,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi ExprResult rhs = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(innerScope, expr, lhs.type, rhs.type, lhs.predicates); + TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type, lhs.predicates); return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2521,8 +2504,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left); - ExprResult rhs = checkExpr(scope, *expr.right); + ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); + ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); PredicateVec predicates; @@ -2621,11 +2604,10 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) { - auto [ty, binding] = checkLValueBinding(scope, expr); - return ty; + return checkLValueBinding(scope, expr); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExpr& expr) { if (auto a = expr.as()) return checkLValueBinding(scope, *a); @@ -2639,22 +2621,22 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { for (AstExpr* expr : a->expressions) checkExpr(scope, *expr); - return {errorRecoveryType(scope), nullptr}; + return errorRecoveryType(scope); } else ice("Unexpected AST node in checkLValue", expr.location); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprLocal& expr) { if (std::optional ty = scope->lookup(expr.local)) - return {*ty, nullptr}; + return *ty; reportError(expr.location, UnknownSymbol{expr.local->name.value, UnknownSymbol::Binding}); - return {errorRecoveryType(scope), nullptr}; + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprGlobal& expr) { Name name = expr.name.value; ScopePtr moduleScope = currentModule->getModuleScope(); @@ -2662,7 +2644,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto it = moduleScope->bindings.find(expr.name); if (it != moduleScope->bindings.end()) - return std::pair(it->second.typeId, &it->second.typeId); + return it->second.typeId; TypeId result = freshType(scope); Binding& binding = moduleScope->bindings[expr.name]; @@ -2673,15 +2655,15 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!isNonstrictMode()) reportError(TypeError{expr.location, UnknownSymbol{name, UnknownSymbol::Binding}}); - return std::pair(result, &binding.typeId); + return result; } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexName& expr) { TypeId lhs = checkExpr(scope, *expr.expr).type; if (get(lhs) || get(lhs)) - return std::pair(lhs, nullptr); + return lhs; tablify(lhs); @@ -2694,7 +2676,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto& it = lhsTable->props.find(name); if (it != lhsTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (lhsTable->state == TableState::Unsealed || lhsTable->state == TableState::Free) { @@ -2702,7 +2684,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Property& property = lhsTable->props[name]; property.type = theType; property.location = expr.indexLocation; - return std::pair(theType, &property.type); + return theType; } else if (auto indexer = lhsTable->indexer) { @@ -2720,17 +2702,17 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope else if (FFlag::LuauUseCommittingTxnLog) state.log.commit(); - return std::pair(retType, nullptr); + return retType; } else if (lhsTable->state == TableState::Sealed) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } else { reportError(TypeError{expr.location, GenericError{"Internal error: generic tables are not lvalues"}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } } else if (const ClassTypeVar* lhsClass = get(lhs)) @@ -2739,29 +2721,29 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{lhs, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; } else if (get(lhs)) { if (std::optional ty = getIndexTypeFromType(scope, lhs, name, expr.location, false)) - return std::pair(*ty, nullptr); + return *ty; // If intersection has a table part, report that it cannot be extended just as a sealed table if (isTableIntersection(lhs)) { reportError(TypeError{expr.location, CannotExtendTable{lhs, CannotExtendTable::Property, name}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } } reportError(TypeError{expr.location, NotATable{lhs}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } -std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) +TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndexExpr& expr) { TypeId exprType = checkExpr(scope, *expr.expr).type; tablify(exprType); @@ -2771,7 +2753,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope TypeId indexType = checkExpr(scope, *expr.index).type; if (get(exprType) || get(exprType)) - return std::pair(exprType, nullptr); + return exprType; AstExprConstantString* value = expr.index->as(); @@ -2783,9 +2765,9 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!prop) { reportError(TypeError{expr.location, UnknownProperty{exprType, value->value.data}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } - return std::pair(prop->type, nullptr); + return prop->type; } } @@ -2794,7 +2776,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope if (!exprTable) { reportError(TypeError{expr.expr->location, NotATable{exprType}}); - return std::pair(errorRecoveryType(scope), nullptr); + return errorRecoveryType(scope); } if (value) @@ -2802,7 +2784,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope const auto& it = exprTable->props.find(value->value.data); if (it != exprTable->props.end()) { - return std::pair(it->second.type, &it->second.type); + return it->second.type; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { @@ -2810,7 +2792,7 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope Property& property = exprTable->props[value->value.data]; property.type = resultType; property.location = expr.index->location; - return std::pair(resultType, &property.type); + return resultType; } } @@ -2818,18 +2800,18 @@ std::pair TypeChecker::checkLValueBinding(const ScopePtr& scope { const TableIndexer& indexer = *exprTable->indexer; unify(indexType, indexer.indexType, expr.index->location); - return std::pair(indexer.indexResultType, nullptr); + return indexer.indexResultType; } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { TypeId resultType = freshType(scope); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; - return std::pair(resultType, nullptr); + return resultType; } else { TypeId resultType = freshType(scope); - return std::pair(resultType, nullptr); + return resultType; } } @@ -3326,7 +3308,7 @@ void TypeChecker::checkArgumentList( } // ok else { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } ++paramIter; @@ -3348,7 +3330,7 @@ void TypeChecker::checkArgumentList( Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } TypePackId tail = state.log.follow(*paramIter.tail()); @@ -3405,7 +3387,7 @@ void TypeChecker::checkArgumentList( if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } } @@ -3520,7 +3502,7 @@ void TypeChecker::checkArgumentList( } // ok else { - state.errors.push_back(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } ++paramIter; @@ -3540,7 +3522,7 @@ void TypeChecker::checkArgumentList( Location location = state.location; if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } TypePackId tail = *paramIter.tail(); @@ -3606,7 +3588,7 @@ void TypeChecker::checkArgumentList( if (!argLocations.empty()) location = {state.location.begin, argLocations.back().end}; // TODO: Better error message? - state.errors.push_back(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); return; } } @@ -3825,22 +3807,11 @@ std::optional> TypeChecker::checkCallOverload(const Scope metaArgLocations = *argLocations; metaArgLocations.insert(metaArgLocations.begin(), expr.func->location); - if (FFlag::LuauFixRecursiveMetatableCall) - { - fn = instantiate(scope, *ty, expr.func->location); - - argPack = metaCallArgPack; - args = metaCallArgs; - argLocations = &metaArgLocations; - } - else - { - TypeId fn = *ty; - fn = instantiate(scope, fn, expr.func->location); + fn = instantiate(scope, *ty, expr.func->location); - return checkCallOverload(scope, expr, fn, retPack, metaCallArgPack, metaCallArgs, &metaArgLocations, argListResult, - overloadsThatMatchArgCount, overloadsThatDont, errors); - } + argPack = metaCallArgPack; + args = metaCallArgs; + argLocations = &metaArgLocations; } } @@ -3932,8 +3903,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope } } - if (FFlag::LuauStoreMatchingOverloadFnType) - currentModule->astOverloadResolvedTypes[&expr] = fn; + currentModule->astOverloadResolvedTypes[&expr] = fn; // We select this overload return {{retPack}}; @@ -4776,7 +4746,7 @@ TypeId TypeChecker::freshType(TypeLevel level) TypeId TypeChecker::singletonType(bool value) { // TODO: cache singleton types - return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BoolSingleton{value}))); + return currentModule->internalTypes.addType(TypeVar(SingletonTypeVar(BooleanSingleton{value}))); } TypeId TypeChecker::singletonType(std::string value) @@ -4813,13 +4783,6 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } -TypeId TypeChecker::addType(const UnionTypeVar& utv) -{ - LUAU_ASSERT(utv.options.size() > 1); - - return addTV(TypeVar(utv)); -} - TypeId TypeChecker::addTV(TypeVar&& tv) { return currentModule->internalTypes.addType(std::move(tv)); @@ -5347,54 +5310,35 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; - if (FFlag::LuauCloneCorrectlyBeforeMutatingTableType) - { - // TODO: CLI-46926 it's not a good idea to rename the type here - TypeId target = follow(instantiated); - bool needsClone = follow(tf.type) == target; - TableTypeVar* ttv = getMutableTableType(target); + // TODO: CLI-46926 it's not a good idea to rename the type here + TypeId target = follow(instantiated); + bool needsClone = follow(tf.type) == target; + TableTypeVar* ttv = getMutableTableType(target); - if (ttv && needsClone) + if (ttv && needsClone) + { + // Substitution::clone is a shallow clone. If this is a metatable type, we + // want to mutate its table, so we need to explicitly clone that table as + // well. If we don't, we will mutate another module's type surface and cause + // a use-after-free. + if (get(target)) { - // Substitution::clone is a shallow clone. If this is a metatable type, we - // want to mutate its table, so we need to explicitly clone that table as - // well. If we don't, we will mutate another module's type surface and cause - // a use-after-free. - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - MetatableTypeVar* mtv = getMutable(instantiated); - mtv->table = applyTypeFunction.clone(mtv->table); - ttv = getMutable(mtv->table); - } - if (get(target)) - { - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutable(instantiated); - } + instantiated = applyTypeFunction.clone(tf.type); + MetatableTypeVar* mtv = getMutable(instantiated); + mtv->table = applyTypeFunction.clone(mtv->table); + ttv = getMutable(mtv->table); } - - if (ttv) + if (get(target)) { - ttv->instantiatedTypeParams = typeParams; - ttv->instantiatedTypePackParams = typePackParams; + instantiated = applyTypeFunction.clone(tf.type); + ttv = getMutable(instantiated); } } - else - { - if (TableTypeVar* ttv = getMutableTableType(instantiated)) - { - if (follow(tf.type) == instantiated) - { - // This can happen if a type alias has generics that it does not use at all. - // ex type FooBar = { a: number } - instantiated = applyTypeFunction.clone(tf.type); - ttv = getMutableTableType(instantiated); - } - ttv->instantiatedTypeParams = typeParams; - ttv->instantiatedTypePackParams = typePackParams; - } + if (ttv) + { + ttv->instantiatedTypeParams = typeParams; + ttv->instantiatedTypePackParams = typePackParams; } return instantiated; @@ -5482,6 +5426,85 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st return {generics, genericPacks}; } +void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) +{ + LUAU_ASSERT(FFlag::LuauDiscriminableUnions); + + const LValue* target = &lvalue; + std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. + + auto ty = resolveLValue(scope, *target); + if (!ty) + return; // Do nothing. An error was already reported. + + // If the provided lvalue is a local or global, then that's without a doubt the target. + // However, if there is a base lvalue, then we'll want that to be the target iff the base is a union type. + if (auto base = baseof(lvalue)) + { + std::optional baseTy = resolveLValue(scope, *base); + if (baseTy && get(follow(*baseTy))) + { + ty = baseTy; + target = base; + key = lvalue; + } + } + + // If we do not have a key, it means we're not trying to discriminate anything, so it's a simple matter of just filtering for a subset. + if (!key) + { + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, *target, *result); + else + addRefinement(refis, *target, errorRecoveryType(scope)); + + return; + } + + // Otherwise, we'll want to walk each option of ty, get its index type, and filter that. + auto utv = get(follow(*ty)); + LUAU_ASSERT(utv); + + std::unordered_set viableTargetOptions; + std::unordered_set viableChildOptions; // There may be additional refinements that apply. We add those here too. + + for (TypeId option : utv) + { + std::optional discriminantTy; + if (auto field = Luau::get(*key)) // need to fully qualify Luau::get because of ADL. + discriminantTy = getIndexTypeFromType(scope, option, field->key, Location(), false); + else + LUAU_ASSERT(!"Unhandled LValue alternative?"); + + if (!discriminantTy) + return; // Do nothing. An error was already reported, as per usual. + + if (std::optional result = filterMap(*discriminantTy, predicate)) + { + viableTargetOptions.insert(option); + viableChildOptions.insert(*result); + } + } + + auto intoType = [this](const std::unordered_set& s) -> std::optional { + if (s.empty()) + return std::nullopt; + + // TODO: allocate UnionTypeVar and just normalize. + std::vector options(s.begin(), s.end()); + if (options.size() == 1) + return options[0]; + + return addType(UnionTypeVar{std::move(options)}); + }; + + if (std::optional viableTargetType = intoType(viableTargetOptions)) + addRefinement(refis, *target, *viableTargetType); + + if (std::optional viableChildType = intoType(viableChildOptions)) + addRefinement(refis, lvalue, *viableChildType); +} + std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { if (!FFlag::LuauLValueAsKey) @@ -5645,18 +5668,29 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi return std::nullopt; }; - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; + if (FFlag::LuauDiscriminableUnions) + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + refineLValue(truthyP.lvalue, refis, scope, predicate); + } + else + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (!ty) + return; - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + // This is a hack. :( + // Without this, the expression 'a or b' might refine 'b' to be falsy. + // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. + if (fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, truthyP.lvalue, *result); + } } void TypeChecker::resolve(const AndPredicate& andP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) @@ -5746,16 +5780,23 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - std::optional ty = resolveLValue(refis, scope, isaP.lvalue); - if (!ty) - return; - - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, isaP.lvalue, *result); + if (FFlag::LuauDiscriminableUnions) + { + refineLValue(isaP.lvalue, refis, scope, predicate); + } else { - addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); - errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); + std::optional ty = resolveLValue(refis, scope, isaP.lvalue); + if (!ty) + return; + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, isaP.lvalue, *result); + else + { + addRefinement(refis, isaP.lvalue, errorRecoveryType(scope)); + errVec.push_back(TypeError{isaP.location, TypeMismatch{isaP.ty, *ty}}); + } } } @@ -5814,21 +5855,30 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (std::optional result = filterMap(*ty, it->second(sense))) - addRefinement(refis, typeguardP.lvalue, *result); - else + if (FFlag::LuauDiscriminableUnions) { - addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); - if (sense) - errVec.push_back( - TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); + refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); + return; } + else + { + if (std::optional result = filterMap(*ty, it->second(sense))) + addRefinement(refis, typeguardP.lvalue, *result); + else + { + addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); + if (sense) + errVec.push_back( + TypeError{typeguardP.location, GenericError{"Type '" + toString(*ty) + "' has no overlap with '" + typeguardP.kind + "'"}}); + } - return; + return; + } } auto fail = [&](const TypeErrorData& err) { - errVec.push_back(TypeError{typeguardP.location, err}); + if (!FFlag::LuauDiscriminableUnions) + errVec.push_back(TypeError{typeguardP.location, err}); addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; @@ -5853,55 +5903,85 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense) { // This refinement will require success typing to do everything correctly. For now, we can get most of the way there. - auto options = [](TypeId ty) -> std::vector { if (auto utv = get(follow(ty))) return std::vector(begin(utv), end(utv)); return {ty}; }; - if (FFlag::LuauWeakEqConstraint) + if (FFlag::LuauDiscriminableUnions) { - if (!sense && isNil(eqP.type)) - resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + std::vector rhs = options(eqP.type); - return; - } + if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - if (FFlag::LuauEqConstraint) - { - std::optional ty = resolveLValue(refis, scope, eqP.lvalue); - if (!ty) - return; + auto predicate = [&](TypeId option) -> std::optional { + if (sense && isUndecidable(option)) + return FFlag::LuauWeakEqConstraint ? option : eqP.type; - std::vector lhs = options(*ty); - std::vector rhs = options(eqP.type); + if (!sense && isNil(eqP.type)) + return (isUndecidable(option) || !isNil(option)) ? std::optional(option) : std::nullopt; + + if (maybeSingleton(eqP.type)) + { + // Normally we'd write option <: eqP.type, but singletons are always the subtype, so we flip this. + if (!sense || canUnify(eqP.type, option, eqP.location).empty()) + return sense ? eqP.type : option; + + return std::nullopt; + } + + return option; + }; - if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) + refineLValue(eqP.lvalue, refis, scope, predicate); + } + else + { + if (FFlag::LuauWeakEqConstraint) { - addRefinement(refis, eqP.lvalue, eqP.type); + if (!sense && isNil(eqP.type)) + resolve(TruthyPredicate{std::move(eqP.lvalue), eqP.location}, errVec, refis, scope, true, /* fromOr= */ false); + return; } - else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) - return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - std::unordered_set set; - for (TypeId left : lhs) + if (FFlag::LuauEqConstraint) { - for (TypeId right : rhs) + std::optional ty = resolveLValue(refis, scope, eqP.lvalue); + if (!ty) + return; + + std::vector lhs = options(*ty); + std::vector rhs = options(eqP.type); + + if (sense && std::any_of(lhs.begin(), lhs.end(), isUndecidable)) { - // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. - if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) - set.insert(left); + addRefinement(refis, eqP.lvalue, eqP.type); + return; } - } + else if (sense && std::any_of(rhs.begin(), rhs.end(), isUndecidable)) + return; // Optimization: the other side has unknown types, so there's probably an overlap. Refining is no-op here. - if (set.empty()) - return; + std::unordered_set set; + for (TypeId left : lhs) + { + for (TypeId right : rhs) + { + // When singleton types arrive, `isNil` here probably should be replaced with `isLiteral`. + if (canUnify(right, left, eqP.location).empty() == sense || (!sense && !isNil(left))) + set.insert(left); + } + } + + if (set.empty()) + return; - std::vector viable(set.begin(), set.end()); - TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); - addRefinement(refis, eqP.lvalue, result); + std::vector viable(set.begin(), set.end()); + TypeId result = viable.size() == 1 ? viable[0] : addType(UnionTypeVar{std::move(viable)}); + addRefinement(refis, eqP.lvalue, result); + } } } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index df5d76ed0..5b162b31b 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -18,14 +18,15 @@ #include #include +LUAU_FASTFLAG(DebugLuauFreezeArena) + LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauLengthOnCompositeType) LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) -LUAU_FASTFLAGVARIABLE(LuauRefactorTagging, false) +LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAG(DebugLuauFreezeArena) namespace Luau { @@ -144,7 +145,20 @@ bool isNil(TypeId ty) bool isBoolean(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::Boolean); + if (FFlag::LuauRefactorTypeVarQuestions) + { + if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) + return true; + + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isBoolean); + + return false; + } + else + { + return isPrim(ty, PrimitiveTypeVar::Boolean); + } } bool isNumber(TypeId ty) @@ -154,7 +168,20 @@ bool isNumber(TypeId ty) bool isString(TypeId ty) { - return isPrim(ty, PrimitiveTypeVar::String); + if (FFlag::LuauRefactorTypeVarQuestions) + { + if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) + return true; + + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isString); + + return false; + } + else + { + return isPrim(ty, PrimitiveTypeVar::String); + } } bool isThread(TypeId ty) @@ -167,37 +194,45 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - if (!get(follow(ty))) - return false; + if (FFlag::LuauRefactorTypeVarQuestions) + { + auto utv = get(follow(ty)); + if (!utv) + return false; - std::unordered_set seen; - std::deque queue{ty}; - while (!queue.empty()) + return std::any_of(begin(utv), end(utv), isNil); + } + else { - TypeId current = follow(queue.front()); - queue.pop_front(); + std::unordered_set seen; + std::deque queue{ty}; + while (!queue.empty()) + { + TypeId current = follow(queue.front()); + queue.pop_front(); - if (seen.count(current)) - continue; + if (seen.count(current)) + continue; - seen.insert(current); + seen.insert(current); - if (isNil(current)) - return true; + if (isNil(current)) + return true; - if (auto u = get(current)) - { - for (TypeId option : u->options) + if (auto u = get(current)) { - if (isNil(option)) - return true; + for (TypeId option : u->options) + { + if (isNil(option)) + return true; - queue.push_back(option); + queue.push_back(option); + } } } - } - return false; + return false; + } } bool isTableIntersection(TypeId ty) @@ -228,13 +263,27 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) + else if (FFlag::LuauRefactorTypeVarQuestions) { - LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); - return primitiveType->metatable; + if (isString(type)) + { + auto ptv = get(getSingletonTypes().stringType); + LUAU_ASSERT(ptv && ptv->metatable); + return ptv->metatable; + } + else + return std::nullopt; } else - return std::nullopt; + { + if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) + { + LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); + return primitiveType->metatable; + } + else + return std::nullopt; + } } const TableTypeVar* getTableType(TypeId type) @@ -696,7 +745,7 @@ TypeId SingletonTypes::makeStringMetatable() {"reverse", {stringToStringType}}, {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"upper", {stringToStringType}}, - {"split", {makeFunction(*arena, stringType, {}, {}, {stringType, optionalString}, {}, + {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}}, {"pack", {arena->addType(FunctionTypeVar{ arena->addTypePack(TypePack{{stringType}, anyTypePack}), @@ -1108,30 +1157,14 @@ static Tags* getTags(TypeId ty) void attachTag(TypeId ty, const std::string& tagName) { - if (!FFlag::LuauRefactorTagging) - { - if (auto ftv = getMutable(ty)) - { - ftv->tags.emplace_back(tagName); - } - else - { - LUAU_ASSERT(!"Got a non functional type"); - } - } + if (auto tags = getTags(ty)) + tags->push_back(tagName); else - { - if (auto tags = getTags(ty)) - tags->push_back(tagName); - else - LUAU_ASSERT(!"This TypeId does not support tags"); - } + LUAU_ASSERT(!"This TypeId does not support tags"); } void attachTag(Property& prop, const std::string& tagName) { - LUAU_ASSERT(FFlag::LuauRefactorTagging); - prop.tags.push_back(tagName); } @@ -1140,7 +1173,6 @@ void attachTag(Property& prop, const std::string& tagName) // Unfortunately, there's already use cases that's hard to disentangle. For now, we expose it. bool hasTag(const Tags& tags, const std::string& tagName) { - LUAU_ASSERT(FFlag::LuauRefactorTagging); return std::find(tags.begin(), tags.end(), tagName) != tags.end(); } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 2bd9cf83f..17d9bf58f 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -17,15 +17,11 @@ LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); -LUAU_FASTFLAGVARIABLE(LuauUnionHeuristic, false) LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) -LUAU_FASTFLAGVARIABLE(LuauOccursCheckOkWithRecursiveFunctions, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedUnionMismatchError, false) -LUAU_FASTFLAGVARIABLE(LuauExtendedFunctionMismatchError, false) namespace Luau { @@ -229,8 +225,6 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { - LUAU_ASSERT(FFlag::LuauExtendedUnionMismatchError); - type = follow(type); if (auto ttv = get(type)) @@ -291,7 +285,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + reportError(TypeError{location, UnificationTooComplex{}}); return; } @@ -403,7 +397,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic subtype escaping scope"}}); + reportError(TypeError{location, GenericError{"Generic subtype escaping scope"}}); return; } @@ -448,7 +442,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 - errors.push_back(TypeError{location, GenericError{"Generic supertype escaping scope"}}); + reportError(TypeError{location, GenericError{"Generic supertype escaping scope"}}); return; } @@ -561,13 +555,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (failed) { if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } } else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) @@ -582,50 +576,44 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool bool foundHeuristic = false; size_t startIndex = 0; - if (FFlag::LuauUnionHeuristic) + if (const std::string* subName = getName(subTy)) { - if (const std::string* subName = getName(subTy)) + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) + const std::string* optionName = getName(uv->options[i]); + if (optionName && *optionName == *subName) { - const std::string* optionName = getName(uv->options[i]); - if (optionName && *optionName == *subName) - { - foundHeuristic = true; - startIndex = i; - break; - } + foundHeuristic = true; + startIndex = i; + break; } } + } - if (FFlag::LuauExtendedUnionMismatchError) + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - if (auto subMatchTag = getTableMatchTag(subTy)) + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) { - for (size_t i = 0; i < uv->options.size(); ++i) - { - auto optionMatchTag = getTableMatchTag(uv->options[i]); - if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) - { - foundHeuristic = true; - startIndex = i; - break; - } - } + foundHeuristic = true; + startIndex = i; + break; } } + } - if (!foundHeuristic && cacheEnabled) + if (!foundHeuristic && cacheEnabled) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) - { - TypeId type = uv->options[i]; + TypeId type = uv->options[i]; - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + { + startIndex = i; + break; } } } @@ -650,7 +638,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { unificationTooComplex = e; } - else if (FFlag::LuauExtendedUnionMismatchError && !isNil(type)) + else if (!isNil(type)) { failedOptionCount++; @@ -664,15 +652,15 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (unificationTooComplex) { - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); } else if (!found) { - if (FFlag::LuauExtendedUnionMismatchError && (failedOptionCount == 1 || foundHeuristic) && failedOption) - errors.push_back( + if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + reportError( TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); } } else if (const IntersectionTypeVar* uv = @@ -702,9 +690,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (firstFailedOption) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); } else if (const IntersectionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) @@ -754,10 +742,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } if (unificationTooComplex) - errors.push_back(*unificationTooComplex); + reportError(*unificationTooComplex); else if (!found) { - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } } else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || @@ -801,7 +789,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); if (FFlag::LuauUseCommittingTxnLog) log.popSeen(superTy, subTy); @@ -1067,7 +1055,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (FInt::LuauTypeInferIterationLimit > 0 && FInt::LuauTypeInferIterationLimit < sharedState.counters.iterationCount) { - errors.push_back(TypeError{location, UnificationTooComplex{}}); + reportError(TypeError{location, UnificationTooComplex{}}); return; } @@ -1166,7 +1154,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { tryUnify_(*subIter, *superIter); - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + if (!errors.empty() && !firstPackErrorPos) firstPackErrorPos = loopCount; superIter.advance(); @@ -1251,7 +1239,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); while (superIter.good()) { @@ -1272,7 +1260,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } else @@ -1372,7 +1360,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal { tryUnify_(*subIter, *superIter); - if (FFlag::LuauExtendedFunctionMismatchError && !errors.empty() && !firstPackErrorPos) + if (!errors.empty() && !firstPackErrorPos) firstPackErrorPos = loopCount; superIter.advance(); @@ -1459,7 +1447,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal size_t actualSize = size(subTp); if (ctx == CountMismatch::Result) std::swap(expectedSize, actualSize); - errors.push_back(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); while (superIter.good()) { @@ -1480,7 +1468,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify type packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); } } } @@ -1493,7 +1481,7 @@ void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) ice("passed non primitive types to unifyPrimitives"); if (superPrim->type != subPrim->type) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) @@ -1508,13 +1496,13 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) if (superSingleton && *superSingleton == *subSingleton) return; - if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) + if (superPrim && superPrim->type == PrimitiveTypeVar::Boolean && get(subSingleton) && variance == Covariant) return; if (superPrim && superPrim->type == PrimitiveTypeVar::String && get(subSingleton) && variance == Covariant) return; - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) @@ -1536,10 +1524,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { numGenerics = std::min(superFunction->generics.size(), subFunction->generics.size()); - if (FFlag::LuauExtendedFunctionMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type parameters"}}); } size_t numGenericPacks = superFunction->genericPacks.size(); @@ -1547,10 +1532,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { numGenericPacks = std::min(superFunction->genericPacks.size(), subFunction->genericPacks.size()); - if (FFlag::LuauExtendedFunctionMismatchError) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); - else - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "different number of generic type pack parameters"}}); } for (size_t i = 0; i < numGenerics; i++) @@ -1567,48 +1549,35 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { Unifier innerState = makeChildUnifier(); - if (FFlag::LuauExtendedFunctionMismatchError) - { - innerState.ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); + innerState.ctx = CountMismatch::Arg; + innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); + + bool reported = !innerState.errors.empty(); - bool reported = !innerState.errors.empty(); + if (auto e = hasUnificationTooComplex(innerState.errors)) + reportError(*e); + else if (!innerState.errors.empty() && innerState.firstPackErrorPos) + reportError( + TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + innerState.errors.front()}}); + else if (!innerState.errors.empty()) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); + innerState.ctx = CountMismatch::Result; + innerState.tryUnify_(subFunction->retType, superFunction->retType); + + if (!reported) + { if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); + reportError(*e); + else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - errors.push_back( - TypeError{location, TypeMismatch{superTy, subTy, format("Argument #%d type is not compatible.", *innerState.firstPackErrorPos), + reportError( + TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), innerState.errors.front()}}); else if (!innerState.errors.empty()) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); - - innerState.ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); - - if (!reported) - { - if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); - else if (!innerState.errors.empty() && size(superFunction->retType) == 1 && finite(superFunction->retType)) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "Return type is not compatible.", innerState.errors.front()}}); - else if (!innerState.errors.empty() && innerState.firstPackErrorPos) - errors.push_back( - TypeError{location, TypeMismatch{superTy, subTy, format("Return #%d type is not compatible.", *innerState.firstPackErrorPos), - innerState.errors.front()}}); - else if (!innerState.errors.empty()) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); - } - } - else - { - ctx = CountMismatch::Arg; - innerState.tryUnify_(superFunction->argTypes, subFunction->argTypes, isFunctionCall); - - ctx = CountMismatch::Result; - innerState.tryUnify_(subFunction->retType, superFunction->retType); - - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); } if (FFlag::LuauUseCommittingTxnLog) @@ -1716,7 +1685,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } } @@ -1734,7 +1703,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } } @@ -1957,13 +1926,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!missingProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingProperties)}}); return; } if (!extraProperties.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraProperties), MissingProperties::Extra}}); return; } @@ -2051,7 +2020,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt return tryUnifySealedTables(subTy, superTy, isIntersection); else if ((superTable->state == TableState::Sealed && subTable->state == TableState::Generic) || (superTable->state == TableState::Generic && subTable->state == TableState::Sealed)) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); else if ((superTable->state == TableState::Free) != (subTable->state == TableState::Free)) // one table is free and the other is not { TypeId freeTypeId = subTable->state == TableState::Free ? subTy : superTy; @@ -2090,7 +2059,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt { const auto& r = subTable->props.find(name); if (r == subTable->props.end()) - errors.push_back(TypeError{location, UnknownProperty{subTy, name}}); + reportError(TypeError{location, UnknownProperty{subTy, name}}); else tryUnify_(r->second.type, prop.type); } @@ -2113,7 +2082,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt } } else - errors.push_back(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); + reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); } } else if (superTable->state == TableState::Sealed) @@ -2194,7 +2163,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) } } else - errors.push_back(TypeError{location, UnknownProperty{subTy, freeName}}); + reportError(TypeError{location, UnknownProperty{subTy, freeName}}); } } @@ -2268,7 +2237,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } } @@ -2284,7 +2253,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec missingPropertiesInSuper.push_back(it.first); - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -2299,7 +2268,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (oldErrorSize != innerState.errors.size() && !errorReported) { errorReported = true; - errors.push_back(innerState.errors.back()); + reportError(innerState.errors.back()); } } else @@ -2340,7 +2309,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } else - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } else { @@ -2369,7 +2338,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec } } else - innerState.errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } } @@ -2386,7 +2355,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!missingPropertiesInSuper.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(missingPropertiesInSuper)}}); return; } @@ -2413,7 +2382,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (!extraPropertiesInSub.empty()) { - errors.push_back(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); + reportError(TypeError{location, MissingProperties{superTy, subTy, std::move(extraPropertiesInSub), MissingProperties::Extra}}); return; } } @@ -2437,9 +2406,9 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) innerState.tryUnify_(subMetatable->metatable, superMetatable->metatable); if (auto e = hasUnificationTooComplex(innerState.errors)) - errors.push_back(*e); + reportError(*e); else if (!innerState.errors.empty()) - errors.push_back( + reportError( TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); if (FFlag::LuauUseCommittingTxnLog) @@ -2470,7 +2439,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) case TableState::Sealed: case TableState::Unsealed: case TableState::Generic: - errors.push_back(mismatchError); + reportError(mismatchError); } } else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) @@ -2479,7 +2448,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) } else { - errors.push_back(mismatchError); + reportError(mismatchError); } } @@ -2491,9 +2460,9 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) auto fail = [&]() { if (!reversed) - errors.push_back(TypeError{location, TypeMismatch{superTy, subTy}}); + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); else - errors.push_back(TypeError{location, TypeMismatch{subTy, superTy}}); + reportError(TypeError{location, TypeMismatch{subTy, superTy}}); }; const ClassTypeVar* superClass = get(superTy); @@ -2538,7 +2507,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!classProp) { ok = false; - errors.push_back(TypeError{location, UnknownProperty{superTy, propName}}); + reportError(TypeError{location, UnknownProperty{superTy, propName}}); } else { @@ -2577,7 +2546,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) { ok = false; std::string msg = "Class " + superClass->name + " does not have an indexer"; - errors.push_back(TypeError{location, GenericError{msg}}); + reportError(TypeError{location, GenericError{msg}}); } if (!ok) @@ -2695,7 +2664,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else if (get(tail)) { - errors.push_back(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); + reportError(TypeError{location, GenericError{"Cannot unify variadic and generic packs"}}); } else if (get(tail)) { @@ -2709,7 +2678,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } else { - errors.push_back(TypeError{location, GenericError{"Failed to unify variadic packs"}}); + reportError(TypeError{location, GenericError{"Failed to unify variadic packs"}}); } } @@ -2886,7 +2855,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); log.replace(needle, *getSingletonTypes().errorRecoveryType()); return; @@ -2894,17 +2863,6 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (log.getMutable(haystack)) return; - else if (auto a = log.getMutable(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypePackIterator it(a->argTypes, &log); it != end(a->argTypes); ++it) - check(*it); - - for (TypePackIterator it(a->retType, &log); it != end(a->retType); ++it) - check(*it); - } - } else if (auto a = log.getMutable(haystack)) { for (TypeId ty : a->options) @@ -2934,7 +2892,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); DEPRECATED_log(needle); *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); return; @@ -2942,17 +2900,6 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays if (get(haystack)) return; - else if (auto a = get(haystack)) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (TypeId ty : a->argTypes) - check(ty); - - for (TypeId ty : a->retType) - check(ty); - } - } else if (auto a = get(haystack)) { for (TypeId ty : a->options) @@ -2988,7 +2935,7 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ if (log.getMutable(needle)) return; - if (!get(needle)) + if (!log.getMutable(needle)) ice("Expected needle pack to be free"); RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); @@ -2997,32 +2944,18 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); return; } - if (auto a = get(haystack)) + if (auto a = get(haystack); a && a->tail) { - for (const auto& ty : a->head) - { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - if (auto f = log.getMutable(log.follow(ty))) - { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); - } - } - } - - if (a->tail) - { - haystack = follow(*a->tail); - continue; - } + haystack = log.follow(*a->tail); + continue; } + break; } } @@ -3048,31 +2981,17 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ { if (needle == haystack) { - errors.push_back(TypeError{location, OccursCheckFailed{}}); + reportError(TypeError{location, OccursCheckFailed{}}); DEPRECATED_log(needle); *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); } - if (auto a = get(haystack)) + if (auto a = get(haystack); a && a->tail) { - if (!FFlag::LuauOccursCheckOkWithRecursiveFunctions) - { - for (const auto& ty : a->head) - { - if (auto f = get(follow(ty))) - { - occursCheck(seen, needle, f->argTypes); - occursCheck(seen, needle, f->retType); - } - } - } - - if (a->tail) - { - haystack = follow(*a->tail); - continue; - } + haystack = follow(*a->tail); + continue; } + break; } } @@ -3094,17 +3013,17 @@ bool Unifier::isNonstrictMode() const void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) - errors.push_back(*e); + reportError(*e); else if (!innerErrors.empty()) - errors.push_back(TypeError{location, TypeMismatch{wantedType, givenType}}); + reportError(TypeError{location, TypeMismatch{wantedType, givenType}}); } void Unifier::checkChildUnifierTypeMismatch(const ErrorVec& innerErrors, const std::string& prop, TypeId wantedType, TypeId givenType) { if (auto e = hasUnificationTooComplex(innerErrors)) - errors.push_back(*e); + reportError(*e); else if (!innerErrors.empty()) - errors.push_back( + reportError( TypeError{location, TypeMismatch{wantedType, givenType, format("Property '%s' is not compatible.", prop.c_str()), innerErrors.front()}}); } diff --git a/CLI/Analyze.cpp b/CLI/Analyze.cpp index e0dc3e0fe..10cf17d21 100644 --- a/CLI/Analyze.cpp +++ b/CLI/Analyze.cpp @@ -43,14 +43,14 @@ static void report(ReportFormat format, const char* name, const Luau::Location& } } -static void reportError(ReportFormat format, const Luau::TypeError& error) +static void reportError(const Luau::Frontend& frontend, ReportFormat format, const Luau::TypeError& error) { - const char* name = error.moduleName.c_str(); + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(error.moduleName); if (const Luau::SyntaxError* syntaxError = Luau::get_if(&error.data)) - report(format, name, error.location, "SyntaxError", syntaxError->message.c_str()); + report(format, humanReadableName.c_str(), error.location, "SyntaxError", syntaxError->message.c_str()); else - report(format, name, error.location, "TypeError", Luau::toString(error).c_str()); + report(format, humanReadableName.c_str(), error.location, "TypeError", Luau::toString(error).c_str()); } static void reportWarning(ReportFormat format, const char* name, const Luau::LintWarning& warning) @@ -72,14 +72,15 @@ static bool analyzeFile(Luau::Frontend& frontend, const char* name, ReportFormat } for (auto& error : cr.errors) - reportError(format, error); + reportError(frontend, format, error); Luau::LintResult lr = frontend.lint(name); + std::string humanReadableName = frontend.fileResolver->getHumanReadableModuleName(name); for (auto& error : lr.errors) - reportWarning(format, name, error); + reportWarning(format, humanReadableName.c_str(), error); for (auto& warning : lr.warnings) - reportWarning(format, name, warning); + reportWarning(format, humanReadableName.c_str(), warning); if (annotate) { @@ -120,11 +121,25 @@ struct CliFileResolver : Luau::FileResolver { std::optional readSource(const Luau::ModuleName& name) override { - std::optional source = readFile(name); + Luau::SourceCode::Type sourceType; + std::optional source = std::nullopt; + + // If the module name is "-", then read source from stdin + if (name == "-") + { + source = readStdin(); + sourceType = Luau::SourceCode::Script; + } + else + { + source = readFile(name); + sourceType = Luau::SourceCode::Module; + } + if (!source) return std::nullopt; - return Luau::SourceCode{*source, Luau::SourceCode::Module}; + return Luau::SourceCode{*source, sourceType}; } std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* node) override @@ -143,6 +158,13 @@ struct CliFileResolver : Luau::FileResolver return std::nullopt; } + + std::string getHumanReadableModuleName(const Luau::ModuleName& name) const override + { + if (name == "-") + return "stdin"; + return name; + } }; struct CliConfigResolver : Luau::ConfigResolver diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index cb993dfee..c68070227 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -74,6 +74,21 @@ std::optional readFile(const std::string& name) return result; } +std::optional readStdin() +{ + std::string result; + char buffer[4096] = { }; + + while (fgets(buffer, sizeof(buffer), stdin) != nullptr) + result.append(buffer); + + // If eof was not reached for stdin, then a read error occurred + if (!feof(stdin)) + return std::nullopt; + + return result; +} + template static void joinPaths(std::basic_string& str, const Ch* lhs, const Ch* rhs) { @@ -190,7 +205,10 @@ bool traverseDirectory(const std::string& path, const std::function getSourceFiles(int argc, char** argv) for (int i = 1; i < argc; ++i) { - if (argv[i][0] == '-') + // Treat '-' as a special file whose source is read from stdin + // All other arguments that start with '-' are skipped + if (argv[i][0] == '-' && argv[i][1] != '\0') continue; if (isDirectory(argv[i])) diff --git a/CLI/FileUtils.h b/CLI/FileUtils.h index da11f512d..97471cdc0 100644 --- a/CLI/FileUtils.h +++ b/CLI/FileUtils.h @@ -7,6 +7,7 @@ #include std::optional readFile(const std::string& name); +std::optional readStdin(); bool isDirectory(const std::string& path); bool traverseDirectory(const std::string& path, const std::function& callback); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index e50421528..ab0f0ed08 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -158,7 +158,7 @@ static int lua_collectgarbage(lua_State* L) luaL_error(L, "collectgarbage must be called with 'count' or 'collect'"); } -static void setupState(lua_State* L) +void setupState(lua_State* L) { luaL_openlibs(L); @@ -176,7 +176,7 @@ static void setupState(lua_State* L) luaL_sandbox(L); } -static std::string runCode(lua_State* L, const std::string& source) +std::string runCode(lua_State* L, const std::string& source) { std::string bytecode = Luau::compile(source, copts()); @@ -206,7 +206,13 @@ static std::string runCode(lua_State* L, const std::string& source) if (n) { luaL_checkstack(T, LUA_MINSTACK, "too many results to print"); - lua_getglobal(T, "print"); + lua_getglobal(T, "_PRETTYPRINT"); + // If _PRETTYPRINT is nil, then use the standard print function instead + if (lua_isnil(T, -1)) + { + lua_pop(T, 1); + lua_getglobal(T, "print"); + } lua_insert(T, 1); lua_pcall(T, n, 0, 0); } @@ -545,7 +551,7 @@ static int assertionHandler(const char* expr, const char* file, int line, const return 1; } -int main(int argc, char** argv) +int replMain(int argc, char** argv) { Luau::assertHandler() = assertionHandler; @@ -696,7 +702,6 @@ int main(int argc, char** argv) case CliMode::Unknown: default: LUAU_ASSERT(!"Unhandled cli mode."); + return 1; } } - - diff --git a/CLI/Repl.h b/CLI/Repl.h new file mode 100644 index 000000000..11a077ae8 --- /dev/null +++ b/CLI/Repl.h @@ -0,0 +1,12 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "lua.h" + +#include + +// Note: These are internal functions which are being exposed in a header +// so they can be included by unit tests. +int replMain(int argc, char** argv); +void setupState(lua_State* L); +std::string runCode(lua_State* L, const std::string& source); diff --git a/CLI/ReplEntry.cpp b/CLI/ReplEntry.cpp new file mode 100644 index 000000000..b31317128 --- /dev/null +++ b/CLI/ReplEntry.cpp @@ -0,0 +1,10 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Repl.h" + + + +int main(int argc, char** argv) +{ + return replMain(argc, argv); +} \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 77cf47e85..b9f7a9e11 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -29,6 +29,7 @@ endif() if(LUAU_BUILD_TESTS) add_executable(Luau.UnitTest) add_executable(Luau.Conformance) + add_executable(Luau.CLI.Test) endif() if(LUAU_BUILD_WEB) @@ -109,6 +110,17 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.Conformance PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Conformance PRIVATE extern) target_link_libraries(Luau.Conformance PRIVATE Luau.Analysis Luau.Compiler Luau.VM) + + target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) + target_include_directories(Luau.CLI.Test PRIVATE extern CLI) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM) + if(UNIX) + find_library(LIBPTHREAD pthread) + if (LIBPTHREAD) + target_link_libraries(Luau.CLI.Test PRIVATE pthread) + endif() + endif() + endif() if(LUAU_BUILD_WEB) diff --git a/Compiler/include/Luau/Bytecode.h b/Compiler/include/Luau/Bytecode.h index d9694d7d0..679712f60 100644 --- a/Compiler/include/Luau/Bytecode.h +++ b/Compiler/include/Luau/Bytecode.h @@ -472,6 +472,9 @@ enum LuauBuiltinFunction // bit32.count LBF_BIT32_COUNTLZ, LBF_BIT32_COUNTRZ, + + // select(_, ...) + LBF_SELECT_VARARG, }; // Capture type, used in LOP_CAPTURE diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index e344eb917..a907271c9 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,6 +4,8 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" +LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin, false) + namespace Luau { namespace Compile @@ -62,6 +64,9 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; + if (FFlag::LuauCompileSelectBuiltin && builtin.isGlobal("select")) + return LBF_SELECT_VARARG; + if (builtin.object == "math") { if (builtin.method == "abs") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 9758c4a9a..7da852447 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -15,6 +15,9 @@ #include #include +LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false) +LUAU_FASTFLAG(LuauCompileSelectBuiltin) + namespace Luau { @@ -261,6 +264,122 @@ struct Compiler bytecode.emitABC(LOP_GETVARARGS, target, multRet ? 0 : uint8_t(targetCount + 1), 0); } + void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs) + { + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + LUAU_ASSERT(targetCount == 1); + LUAU_ASSERT(!expr->self); + LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is()); + + AstExpr* arg = expr->args.data[0]; + + uint8_t argreg; + + if (isExprLocalReg(arg)) + argreg = getLocal(arg->as()->local); + else + { + argreg = uint8_t(regs + 1); + compileExprTempTop(arg, argreg); + } + + size_t fastcallLabel = bytecode.emitLabel(); + + bytecode.emitABC(LOP_FASTCALL1, LBF_SELECT_VARARG, argreg, 0); + + // note, these instructions are normally not executed and are used as a fallback for FASTCALL + // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten + compileExprTemp(expr->func, regs); + + bytecode.emitABC(LOP_GETVARARGS, uint8_t(regs + 2), 0, 0); + + size_t callLabel = bytecode.emitLabel(); + if (!bytecode.patchSkipC(fastcallLabel, callLabel)) + CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile"); + + // note, this is always multCall (last argument is variadic) + bytecode.emitABC(LOP_CALL, regs, 0, multRet ? 0 : uint8_t(targetCount + 1)); + + // if we didn't output results directly to target, we need to move them + if (!targetTop) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0); + } + } + + void compileExprFastcallN(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs, int bfid) + { + LUAU_ASSERT(!expr->self); + LUAU_ASSERT(expr->args.size <= 2); + + LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; + + uint32_t args[2] = {}; + + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0) + { + if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) + { + opc = LOP_FASTCALL2K; + args[i] = cid; + break; + } + } + + if (isExprLocalReg(expr->args.data[i])) + args[i] = getLocal(expr->args.data[i]->as()->local); + else + { + args[i] = uint8_t(regs + 1 + i); + compileExprTempTop(expr->args.data[i], uint8_t(args[i])); + } + } + + size_t fastcallLabel = bytecode.emitLabel(); + + bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); + if (opc != LOP_FASTCALL1) + bytecode.emitAux(args[1]); + + // Set up a traditional Lua stack for the subsequent LOP_CALL. + // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for + // these FASTCALL variants. + for (size_t i = 0; i < expr->args.size; ++i) + { + if (i > 0 && opc == LOP_FASTCALL2K) + { + emitLoadK(uint8_t(regs + 1 + i), args[i]); + break; + } + + if (args[i] != regs + 1 + i) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); + } + + // note, these instructions are normally not executed and are used as a fallback for FASTCALL + // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten + compileExprTemp(expr->func, regs); + + size_t callLabel = bytecode.emitLabel(); + + // FASTCALL will skip over the instructions needed to compute function and jump over CALL which must immediately follow the instruction + // sequence after FASTCALL + if (!bytecode.patchSkipC(fastcallLabel, callLabel)) + CompileError::raise(expr->func->location, "Exceeded jump distance limit; simplify the code to compile"); + + bytecode.emitABC(LOP_CALL, regs, uint8_t(expr->args.size + 1), multRet ? 0 : uint8_t(targetCount + 1)); + + // if we didn't output results directly to target, we need to move them + if (!targetTop) + { + for (size_t i = 0; i < targetCount; ++i) + bytecode.emitABC(LOP_MOVE, uint8_t(target + i), uint8_t(regs + i), 0); + } + } + void compileExprCall(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop = false, bool multRet = false) { LUAU_ASSERT(!targetTop || unsigned(target + targetCount) == regTop); @@ -284,6 +403,25 @@ struct Compiler bfid = getBuiltinFunctionId(builtin, options); } + if (bfid == LBF_SELECT_VARARG) + { + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly + // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases + if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) + return compileExprSelectVararg(expr, target, targetCount, targetTop, multRet, regs); + else + bfid = -1; + } + + // Optimization: for 1/2 argument fast calls use specialized opcodes + if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) + { + AstExpr* last = expr->args.data[expr->args.size - 1]; + if (!last->is() && !last->is()) + return compileExprFastcallN(expr, target, targetCount, targetTop, multRet, regs, bfid); + } + if (expr->self) { AstExprIndexName* fi = expr->func->as(); @@ -309,24 +447,13 @@ struct Compiler compileExprTempTop(expr->func, regs); } - // Note: if the last argument is ExprVararg or ExprCall, we need to route that directly to the called function preserving the # of args bool multCall = false; - bool skipArgs = false; - - if (!expr->self && bfid >= 0 && expr->args.size >= 1 && expr->args.size <= 2) - { - AstExpr* last = expr->args.data[expr->args.size - 1]; - skipArgs = !(last->is() || last->is()); - } - if (!skipArgs) - { - for (size_t i = 0; i < expr->args.size; ++i) - if (i + 1 == expr->args.size) - multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); - else - compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); - } + for (size_t i = 0; i < expr->args.size; ++i) + if (i + 1 == expr->args.size) + multCall = compileExprTempMultRet(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); + else + compileExprTempTop(expr->args.data[i], uint8_t(regs + 1 + expr->self + i)); setDebugLineEnd(expr->func); @@ -347,59 +474,8 @@ struct Compiler } else if (bfid >= 0) { - size_t fastcallLabel; - - if (skipArgs) - { - LuauOpcode opc = expr->args.size == 1 ? LOP_FASTCALL1 : LOP_FASTCALL2; - - uint32_t args[2] = {}; - for (size_t i = 0; i < expr->args.size; ++i) - { - if (i > 0) - { - if (int32_t cid = getConstantIndex(expr->args.data[i]); cid >= 0) - { - opc = LOP_FASTCALL2K; - args[i] = cid; - break; - } - } - - if (isExprLocalReg(expr->args.data[i])) - args[i] = getLocal(expr->args.data[i]->as()->local); - else - { - args[i] = uint8_t(regs + 1 + i); - compileExprTempTop(expr->args.data[i], uint8_t(args[i])); - } - } - - fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(opc, uint8_t(bfid), uint8_t(args[0]), 0); - if (opc != LOP_FASTCALL1) - bytecode.emitAux(args[1]); - - // Set up a traditional Lua stack for the subsequent LOP_CALL. - // Note, as with other instructions that immediately follow FASTCALL, these are normally not executed and are used as a fallback for - // these FASTCALL variants. - for (size_t i = 0; i < expr->args.size; ++i) - { - if (i > 0 && opc == LOP_FASTCALL2K) - { - emitLoadK(uint8_t(regs + 1 + i), args[i]); - break; - } - - if (args[i] != regs + 1 + i) - bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1 + i), uint8_t(args[i]), 0); - } - } - else - { - fastcallLabel = bytecode.emitLabel(); - bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0); - } + size_t fastcallLabel = bytecode.emitLabel(); + bytecode.emitABC(LOP_FASTCALL, uint8_t(bfid), 0, 0); // note, these instructions are normally not executed and are used as a fallback for FASTCALL // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten @@ -1101,9 +1177,20 @@ struct Compiler for (size_t i = 0; i < expr->items.size; ++i) { const AstExprTable::Item& item = expr->items.data[i]; - AstExprConstantNumber* ckey = item.key->as(); + LUAU_ASSERT(item.key); // no list portion => all items have keys + + if (FFlag::LuauCompileTableIndexOpt) + { + const Constant* ckey = constants.find(item.key); + + indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1)); + } + else + { + AstExprConstantNumber* ckey = item.key->as(); - indexSize += (ckey && ckey->value == double(indexSize + 1)); + indexSize += (ckey && ckey->value == double(indexSize + 1)); + } } // we only perform the optimization if we don't have any other []-keys @@ -1200,37 +1287,47 @@ struct Compiler arrayChunkCurrent = 0; } - // items with a key are set one by one via SETTABLE/SETTABLEKS + // items with a key are set one by one via SETTABLE/SETTABLEKS/SETTABLEN if (key) { RegScope rsi(this); - // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax - if (AstExprConstantString* ckey = key->as()) - { - BytecodeBuilder::StringRef cname = sref(ckey->value); - int32_t cid = bytecode.addConstantString(cname); - if (cid < 0) - CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); - bytecode.emitAux(cid); - } - else if (AstExprConstantNumber* ckey = key->as(); - ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) + if (FFlag::LuauCompileTableIndexOpt) { + LValue lv = compileLValueIndex(reg, key, rsi); uint8_t rv = compileExprAuto(value, rsi); - bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); + compileAssign(lv, rv); } else { - uint8_t rk = compileExprAuto(key, rsi); - uint8_t rv = compileExprAuto(value, rsi); + // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax + if (AstExprConstantString* ckey = key->as()) + { + BytecodeBuilder::StringRef cname = sref(ckey->value); + int32_t cid = bytecode.addConstantString(cname); + if (cid < 0) + CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); + bytecode.emitAux(cid); + } + else if (AstExprConstantNumber* ckey = key->as(); + ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) + { + uint8_t rv = compileExprAuto(value, rsi); + + bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); + } + else + { + uint8_t rk = compileExprAuto(key, rsi); + uint8_t rv = compileExprAuto(value, rsi); - bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); + bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); + } } } // items without a key are set using SETLIST so that we can initialize large arrays quickly @@ -1339,6 +1436,9 @@ struct Compiler uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(expr->index); + bytecode.emitABC(LOP_GETTABLEN, target, rt, i); } else if (cv && cv->type == Constant::Type_String) @@ -1350,6 +1450,9 @@ struct Compiler if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(expr->index); + bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); } @@ -1657,6 +1760,40 @@ struct Compiler Location location; }; + LValue compileLValueIndex(uint8_t reg, AstExpr* index, RegScope& rs) + { + const Constant* cv = constants.find(index); + + if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && + double(int(cv->valueNumber)) == cv->valueNumber) + { + LValue result = {LValue::Kind_IndexNumber}; + result.reg = reg; + result.number = uint8_t(int(cv->valueNumber) - 1); + result.location = index->location; + + return result; + } + else if (cv && cv->type == Constant::Type_String) + { + LValue result = {LValue::Kind_IndexName}; + result.reg = reg; + result.name = sref(cv->getString()); + result.location = index->location; + + return result; + } + else + { + LValue result = {LValue::Kind_IndexExpr}; + result.reg = reg; + result.index = compileExprAuto(index, rs); + result.location = index->location; + + return result; + } + } + LValue compileLValue(AstExpr* node, RegScope& rs) { setDebugLine(node); @@ -1699,36 +1836,9 @@ struct Compiler } else if (AstExprIndexExpr* expr = node->as()) { - const Constant* cv = constants.find(expr->index); - - if (cv && cv->type == Constant::Type_Number && cv->valueNumber >= 1 && cv->valueNumber <= 256 && - double(int(cv->valueNumber)) == cv->valueNumber) - { - LValue result = {LValue::Kind_IndexNumber}; - result.reg = compileExprAuto(expr->expr, rs); - result.number = uint8_t(int(cv->valueNumber) - 1); - result.location = node->location; - - return result; - } - else if (cv && cv->type == Constant::Type_String) - { - LValue result = {LValue::Kind_IndexName}; - result.reg = compileExprAuto(expr->expr, rs); - result.name = sref(cv->getString()); - result.location = node->location; - - return result; - } - else - { - LValue result = {LValue::Kind_IndexExpr}; - result.reg = compileExprAuto(expr->expr, rs); - result.index = compileExprAuto(expr->index, rs); - result.location = node->location; + uint8_t reg = compileExprAuto(expr->expr, rs); - return result; - } + return compileLValueIndex(reg, expr->index, rs); } else { @@ -1740,6 +1850,9 @@ struct Compiler void compileLValueUse(const LValue& lv, uint8_t reg, bool set) { + if (FFlag::LuauCompileTableIndexOpt) + setDebugLine(lv.location); + switch (lv.kind) { case LValue::Kind_Local: diff --git a/Makefile b/Makefile index b144cac60..638c4c635 100644 --- a/Makefile +++ b/Makefile @@ -23,11 +23,11 @@ VM_SOURCES=$(wildcard VM/src/*.cpp) VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o) VM_TARGET=$(BUILD)/libluauvm.a -TESTS_SOURCES=$(wildcard tests/*.cpp) +TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests -REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp +REPL_CLI_SOURCES=CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp CLI/ReplEntry.cpp REPL_CLI_OBJECTS=$(REPL_CLI_SOURCES:%=$(BUILD)/%.o) REPL_CLI_TARGET=$(BUILD)/luau @@ -90,11 +90,12 @@ $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include -$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -Iextern +$(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICLI -Iextern $(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include +$(TESTS_TARGET): LDFLAGS+=-lpthread $(REPL_CLI_TARGET): LDFLAGS+=-lpthread fuzz-proto fuzz-prototest: LDFLAGS+=build/libprotobuf-mutator/src/libfuzzer/libprotobuf-mutator-libfuzzer.a build/libprotobuf-mutator/src/libprotobuf-mutator.a build/libprotobuf-mutator/external.protobuf/lib/libprotobuf.a diff --git a/Sources.cmake b/Sources.cmake index bafe75948..22e7af223 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -176,7 +176,8 @@ if(TARGET Luau.Repl.CLI) CLI/FileUtils.cpp CLI/Profiler.h CLI/Profiler.cpp - CLI/Repl.cpp) + CLI/Repl.cpp + CLI/ReplEntry.cpp) endif() if(TARGET Luau.Analyze.CLI) @@ -243,6 +244,21 @@ if(TARGET Luau.Conformance) tests/main.cpp) endif() +if(TARGET Luau.CLI.Test) + # Luau.CLI.Test Sources + target_sources(Luau.CLI.Test PRIVATE + CLI/Coverage.h + CLI/Coverage.cpp + CLI/FileUtils.h + CLI/FileUtils.cpp + CLI/Profiler.h + CLI/Profiler.cpp + CLI/Repl.cpp + + tests/Repl.test.cpp + tests/main.cpp) +endif() + if(TARGET Luau.Web) # Luau.Web Sources target_sources(Luau.Web PRIVATE diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index d5416285e..5cffba63c 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,6 +14,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauGcForwardMetatableBarrier, false) + const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -869,7 +871,16 @@ int lua_setmetatable(lua_State* L, int objindex) luaG_runerror(L, "Attempt to modify a readonly table"); hvalue(obj)->metatable = mt; if (mt) - luaC_objbarriert(L, hvalue(obj), mt); + { + if (FFlag::LuauGcForwardMetatableBarrier) + { + luaC_objbarrier(L, hvalue(obj), mt); + } + else + { + luaC_objbarriert(L, hvalue(obj), mt); + } + } break; } case LUA_TUSERDATA: diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index 34e9ebc1f..ecc14e87b 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1087,6 +1087,34 @@ static int luauF_countrz(lua_State* L, StkId res, TValue* arg0, int nresults, St return -1; } +static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, StkId args, int nparams) +{ + if (nparams == 1 && nresults == 1) + { + int n = cast_int(L->base - L->ci->func) - clvalue(L->ci->func)->l.p->numparams - 1; + + if (ttisnumber(arg0)) + { + int i = int(nvalue(arg0)); + + // i >= 1 && i <= n + if (unsigned(i - 1) <= unsigned(n)) + { + setobj2s(L, res, L->base - n + (i - 1)); + return 1; + } + // note: for now we don't handle negative case (wrap around) and defer to fallback + } + else if (ttisstring(arg0) && *svalue(arg0) == '#') + { + setnvalue(res, double(n)); + return 1; + } + } + + return -1; +} + luau_FastFunction luauF_table[256] = { NULL, luauF_assert, @@ -1156,4 +1184,6 @@ luau_FastFunction luauF_table[256] = { luauF_countlz, luauF_countrz, + + luauF_select, }; diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index abcde7796..192228613 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -5,8 +5,6 @@ #include "lstate.h" #include "lvm.h" -LUAU_FASTFLAGVARIABLE(LuauCoroutineClose, false) - #define CO_RUN 0 /* running */ #define CO_SUS 1 /* suspended */ #define CO_NOR 2 /* 'normal' (it resumed another coroutine) */ @@ -235,9 +233,6 @@ static int coyieldable(lua_State* L) static int coclose(lua_State* L) { - if (!FFlag::LuauCoroutineClose) - luaL_error(L, "coroutine.close is not enabled"); - lua_State* co = lua_tothread(L, 1); luaL_argexpected(L, co, 1, "thread"); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index 581506a89..a3982bc68 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAG(LuauCoroutineClose) - /* ** {====================================================== ** Error-recovery functions @@ -300,7 +298,7 @@ static void resume(lua_State* L, void* ud) { // start coroutine LUAU_ASSERT(L->ci == L->base_ci && firstArg >= L->base); - if (FFlag::LuauCoroutineClose && firstArg == L->base) + if (firstArg == L->base) luaG_runerror(L, "cannot resume dead coroutine"); if (luau_precall(L, firstArg - 1, LUA_MULTRET) != PCRLUA) diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 50859b1e8..82ac00092 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -93,10 +93,8 @@ static void finishGcCycleStats(global_State* g) g->gcstats.lastcycle = g->gcstats.currcycle; g->gcstats.currcycle = GCCycleStats(); - g->gcstats.cyclestatsacc.markitems += g->gcstats.lastcycle.markitems; g->gcstats.cyclestatsacc.marktime += g->gcstats.lastcycle.marktime; g->gcstats.cyclestatsacc.atomictime += g->gcstats.lastcycle.atomictime; - g->gcstats.cyclestatsacc.sweepitems += g->gcstats.lastcycle.sweepitems; g->gcstats.cyclestatsacc.sweeptime += g->gcstats.lastcycle.sweeptime; } @@ -492,23 +490,22 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page) } } -#define sweepwholelist(L, p, tc) sweeplist(L, p, SIZE_MAX, tc) +#define sweepwholelist(L, p) sweeplist(L, p, SIZE_MAX) -static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* traversedcount) +static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count) { LUAU_ASSERT(!FFlag::LuauGcPagedSweep); GCObject* curr; global_State* g = L->global; int deadmask = otherwhite(g); - size_t startcount = count; LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); /* make sure we never sweep fixed objects */ while ((curr = *p) != NULL && count-- > 0) { int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; if (curr->gch.tt == LUA_TTHREAD) { - sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval, traversedcount); /* sweep open upvalues */ + sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval); /* sweep open upvalues */ lua_State* th = gco2th(curr); @@ -534,10 +531,6 @@ static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count, size_t* tr } } - // if we didn't reach the end of the list it means that we've stopped because the count dropped below zero - if (traversedcount) - *traversedcount += startcount - (curr ? count + 1 : count); - return p; } @@ -721,8 +714,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) int alive = (gco->gch.marked ^ WHITEBITS) & deadmask; - g->gcstats.currcycle.sweepitems++; - if (gco->gch.tt == LUA_TTHREAD) { lua_State* th = gco2th(gco); @@ -770,11 +761,11 @@ static int sweepgcopage(lua_State* L, lua_Page* page) { // if the last block was removed, page would be removed as well if (--busyBlocks == 0) - return (pos - start) / blockSize + 1; + return int(pos - start) / blockSize + 1; } } - return (end - start) / blockSize; + return int(end - start) / blockSize; } static size_t gcstep(lua_State* L, size_t limit) @@ -793,8 +784,6 @@ static size_t gcstep(lua_State* L, size_t limit) { while (g->gray && cost < limit) { - g->gcstats.currcycle.markitems++; - cost += propagatemark(g); } @@ -812,8 +801,6 @@ static size_t gcstep(lua_State* L, size_t limit) { while (g->gray && cost < limit) { - g->gcstats.currcycle.markitems++; - cost += propagatemark(g); } @@ -842,10 +829,8 @@ static size_t gcstep(lua_State* L, size_t limit) while (g->sweepstrgc < g->strt.size && cost < limit) { - size_t traversedcount = 0; - sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++], &traversedcount); + sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++]); - g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPCOST; } @@ -855,12 +840,10 @@ static size_t gcstep(lua_State* L, size_t limit) // sweep string buffer list and preserve used string count uint32_t nuse = L->global->strt.nuse; - size_t traversedcount = 0; - sweepwholelist(L, (GCObject**)&g->strbufgc, &traversedcount); + sweepwholelist(L, (GCObject**)&g->strbufgc); L->global->strt.nuse = nuse; - g->gcstats.currcycle.sweepitems += traversedcount; g->gcstate = GCSsweep; // end sweep-string phase } break; @@ -893,10 +876,8 @@ static size_t gcstep(lua_State* L, size_t limit) { while (*g->sweepgc && cost < limit) { - size_t traversedcount = 0; - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX, &traversedcount); + g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX); - g->gcstats.currcycle.sweepitems += traversedcount; cost += GC_SWEEPMAX * GC_SWEEPCOST; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 4455fec5b..528d09446 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -113,6 +113,7 @@ luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } +// TODO: remove with FFlagLuauGcForwardMetatableBarrier #define luaC_objbarriert(L, t, o) \ { \ if (isblack(obj2gco(t)) && iswhite(obj2gco(o))) \ diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 6d3b77724..e1dbce504 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -200,7 +200,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int global_State* g = L->global; - LUAU_ASSERT(pageSize - offsetof(lua_Page, data) >= blockSize * blockCount); + LUAU_ASSERT(pageSize - int(offsetof(lua_Page, data)) >= blockSize * blockCount); lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); if (!page) @@ -376,7 +376,7 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); - LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass]); + LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); void* block; @@ -520,7 +520,7 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) } else { - lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + nsize, nsize, 1); + lua_Page* page = newpage(L, &g->allgcopages, offsetof(lua_Page, data) + int(nsize), int(nsize), 1); block = &page->data; ASAN_UNPOISON_MEMORY_REGION(block, page->blockSize); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 080f00248..0708b71f3 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -96,9 +96,6 @@ struct GCCycleStats double sweeptime = 0.0; - size_t markitems = 0; - size_t sweepitems = 0; - size_t assistwork = 0; size_t explicitwork = 0; diff --git a/tests/AstQuery.test.cpp b/tests/AstQuery.test.cpp index 41b553b55..292625b09 100644 --- a/tests/AstQuery.test.cpp +++ b/tests/AstQuery.test.cpp @@ -44,10 +44,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "prop") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") { - ScopedFastFlag sffs[] = { - {"LuauPersistDefinitionFileTypes", true}, - }; - loadDefinition(R"( declare function Connect(fn: (string) -> ()) )"); @@ -63,8 +59,6 @@ TEST_CASE_FIXTURE(DocumentationSymbolFixture, "event_callback_arg") TEST_CASE_FIXTURE(DocumentationSymbolFixture, "overloaded_fn") { - ScopedFastFlag sffs{"LuauStoreMatchingOverloadFnType", true}; - loadDefinition(R"( declare foo: ((string) -> number) & ((number) -> string) )"); diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 211e1be1f..e8e3b3156 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2626,7 +2626,6 @@ local a: A<(number, s@1> TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag luauAutocompleteFirstArg("LuauAutocompleteFirstArg", true); check(R"( local function foo1() return 1 end @@ -2728,4 +2727,39 @@ end CHECK(ac.entryMap.count("getx")); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") +{ + ScopedFastFlag sffs[] = { + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauRefactorTypeVarQuestions", true}, + }; + + check(R"( + --!strict + local foo: "hello" | "bye" = "hello" + foo:@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("format")); +} + +TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") +{ + ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); + ScopedFastFlag preferToCallFunctionsForIntersects("PreferToCallFunctionsForIntersects", true); + + check(R"( +local bar: ((number) -> number) & (number, number) -> number) +local abc = b@1 + )"); + + auto ac = autocomplete('1'); + + CHECK(ac.entryMap.count("bar")); + CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 3b0d677de..4a28bdde9 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -603,6 +603,37 @@ RETURN R0 1 )"); } +TEST_CASE("TableLiteralsIndexConstant") +{ + ScopedFastFlag sff("LuauCompileTableIndexOpt", true); + + // validate that we use SETTTABLEKS for constant variable keys + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = "key", "value" + return {[a] = 42, [b] = 0} +)"), R"( +NEWTABLE R0 2 0 +LOADN R1 42 +SETTABLEKS R1 R0 K0 +LOADN R1 0 +SETTABLEKS R1 R0 K1 +RETURN R0 1 +)"); + + // validate that we use SETTABLEN for constant variable keys *and* that we predict array size + CHECK_EQ("\n" + compileFunction0(R"( + local a, b = 1, 2 + return {[a] = 42, [b] = 0} +)"), R"( +NEWTABLE R0 0 2 +LOADN R1 42 +SETTABLEN R1 R0 1 +LOADN R1 0 +SETTABLEN R1 R0 2 +RETURN R0 1 +)"); +} + TEST_CASE("TableSizePredictionBasic") { CHECK_EQ("\n" + compileFunction0(R"( @@ -2450,6 +2481,37 @@ return )"); } +TEST_CASE("DebugLineInfoAssignment") +{ + ScopedFastFlag sff("LuauCompileTableIndexOpt", true); + + Luau::BytecodeBuilder bcb; + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); + Luau::compileOrThrow(bcb, R"( + local a = { b = { c = { d = 3 } } } + +a +["b"] +["c"] +["d"] = 4 +)"); + + CHECK_EQ("\n" + bcb.dumpFunction(0), R"( +2: DUPTABLE R0 1 +2: DUPTABLE R1 3 +2: DUPTABLE R2 5 +2: LOADN R3 3 +2: SETTABLEKS R3 R2 K4 +2: SETTABLEKS R2 R1 K2 +2: SETTABLEKS R1 R0 K0 +5: GETTABLEKS R2 R0 K0 +6: GETTABLEKS R1 R2 K2 +7: LOADN R2 4 +7: SETTABLEKS R2 R1 K4 +8: RETURN R0 0 +)"); +} + TEST_CASE("DebugSource") { const char* source = R"( @@ -2763,6 +2825,75 @@ RETURN R1 -1 )"); } +TEST_CASE("FastcallSelect") +{ + ScopedFastFlag sff("LuauCompileSelectBuiltin", true); + + // select(_, ...) compiles to a builtin call + CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( +LOADK R1 K0 +FASTCALL1 57 R1 +3 +GETIMPORT R0 2 +GETVARARGS R2 -1 +CALL R0 -1 1 +RETURN R0 1 +)"); + + // more complex example: select inside a for loop bound + select from a iterator + CHECK_EQ("\n" + compileFunction0(R"( +local sum = 0 +for i=1, select('#', ...) do + sum += select(i, ...) +end +return sum +)"), R"( +LOADN R0 0 +LOADN R3 1 +LOADK R5 K0 +FASTCALL1 57 R5 +3 +GETIMPORT R4 2 +GETVARARGS R6 -1 +CALL R4 -1 1 +MOVE R1 R4 +LOADN R2 1 +FORNPREP R1 +7 +FASTCALL1 57 R3 +3 +GETIMPORT R4 2 +GETVARARGS R6 -1 +CALL R4 -1 1 +ADD R0 R0 R4 +FORNLOOP R1 -7 +RETURN R0 1 +)"); + + // currently we assume a single value return to avoid dealing with stack resizing + CHECK_EQ("\n" + compileFunction0("return select('#', ...)"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETVARARGS R2 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#')"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +CALL R0 1 -1 +RETURN R0 -1 +)"); + + // note that select with a non-variadic second argument doesn't get optimized + CHECK_EQ("\n" + compileFunction0("return select('#', foo())"), R"( +GETIMPORT R0 1 +LOADK R1 K2 +GETIMPORT R2 4 +CALL R2 0 -1 +CALL R0 -1 -1 +RETURN R0 -1 +)"); +} + TEST_CASE("LotsOfParameters") { const char* source = R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 5222af33e..914b881f7 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -331,8 +331,6 @@ TEST_CASE("UTF8") TEST_CASE("Coroutine") { - ScopedFastFlag sff("LuauCoroutineClose", true); - runConformance("coroutine.lua"); } diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index 405f26e07..ea1a08fe7 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -956,7 +956,6 @@ TEST_CASE("no_use_after_free_with_type_fun_instantiation") { // This flag forces this test to crash if there's a UAF in this code. ScopedFastFlag sff_DebugLuauFreezeArena("DebugLuauFreezeArena", true); - ScopedFastFlag sff_LuauCloneCorrectlyBeforeMutatingTableType("LuauCloneCorrectlyBeforeMutatingTableType", true); FrontendFixture fix; diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index ac81005c5..90831ee9d 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2000,6 +2000,73 @@ TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); } +TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") +{ + { + AstStat* stat = parse("return if true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + } + + { + AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use "else if" as opposed to elseif + { + AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr1 = str->list.data[0]->as(); + REQUIRE(ifElseExpr1 != nullptr); + auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); + REQUIRE(ifElseExpr2 != nullptr); + } + + // Use an if-else expression as the conditional expression of an if-else expression + { + AstStat* stat = parse("return if if true then false else true then 1 else 2"); + + REQUIRE(stat != nullptr); + AstStatReturn* str = stat->as()->body.data[0]->as(); + REQUIRE(str != nullptr); + CHECK(str->list.size == 1); + auto* ifElseExpr = str->list.data[0]->as(); + REQUIRE(ifElseExpr != nullptr); + auto* nestedIfElseExpr = ifElseExpr->condition->as(); + REQUIRE(nestedIfElseExpr != nullptr); + } +} + +TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") +{ + AstStat* stat = parse(R"( +type Packed = () -> T... + +type A = Packed +type B = Packed<...number> +type C = Packed<(number, X...)> + )"); + REQUIRE(stat != nullptr); +} + TEST_SUITE_END(); TEST_SUITE_BEGIN("ParseErrorRecovery"); @@ -2504,71 +2571,4 @@ type Y = (T...) -> U... CHECK_EQ(1, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "parse_if_else_expression") -{ - { - AstStat* stat = parse("return if true then 1 else 2"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - } - - { - AstStat* stat = parse("return if true then 1 elseif true then 2 else 3"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } - - // Use "else if" as opposed to elseif - { - AstStat* stat = parse("return if true then 1 else if true then 2 else 3"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr1 = str->list.data[0]->as(); - REQUIRE(ifElseExpr1 != nullptr); - auto* ifElseExpr2 = ifElseExpr1->falseExpr->as(); - REQUIRE(ifElseExpr2 != nullptr); - } - - // Use an if-else expression as the conditional expression of an if-else expression - { - AstStat* stat = parse("return if if true then false else true then 1 else 2"); - - REQUIRE(stat != nullptr); - AstStatReturn* str = stat->as()->body.data[0]->as(); - REQUIRE(str != nullptr); - CHECK(str->list.size == 1); - auto* ifElseExpr = str->list.data[0]->as(); - REQUIRE(ifElseExpr != nullptr); - auto* nestedIfElseExpr = ifElseExpr->condition->as(); - REQUIRE(nestedIfElseExpr != nullptr); - } -} - -TEST_CASE_FIXTURE(Fixture, "parse_type_pack_type_parameters") -{ - AstStat* stat = parse(R"( -type Packed = () -> T... - -type A = Packed -type B = Packed<...number> -type C = Packed<(number, X...)> - )"); - REQUIRE(stat != nullptr); -} - TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp new file mode 100644 index 000000000..f660bcd3f --- /dev/null +++ b/tests/Repl.test.cpp @@ -0,0 +1,117 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "lua.h" +#include "lualib.h" + +#include "Repl.h" + +#include "doctest.h" + +#include +#include +#include +#include + + +class ReplFixture +{ +public: + ReplFixture() + : luaState(luaL_newstate(), lua_close) + { + L = luaState.get(); + setupState(L); + luaL_sandboxthread(L); + + std::string result = runCode(L, prettyPrintSource); + } + + // Returns all of the output captured from the pretty printer + std::string getCapturedOutput() + { + lua_getglobal(L, "capturedoutput"); + const char* str = lua_tolstring(L, -1, nullptr); + std::string result(str); + lua_pop(L, 1); + return result; + } + lua_State* L; + +private: + std::unique_ptr luaState; + + // This is a simplicitic and incomplete pretty printer. + // It is included here to test that the pretty printer hook is being called. + // More elaborate tests to ensure correct output can be added if we introduce + // a more feature rich pretty printer. + std::string prettyPrintSource = R"( +-- Accumulate pretty printer output in `capturedoutput` +capturedoutput = "" + +function arraytostring(arr) + local strings = {} + table.foreachi(arr, function(k,v) table.insert(strings, pptostring(v)) end ) + return "{" .. table.concat(strings, ", ") .. "}" +end + +function pptostring(x) + if type(x) == "table" then + -- Just assume array-like tables for now. + return arraytostring(x) + elseif type(x) == "string" then + return '"' .. x .. '"' + else + return tostring(x) + end +end + +-- Note: Instead of calling print, the pretty printer just stores the output +-- in `capturedoutput` so we can check for the correct results. +function _PRETTYPRINT(...) + local args = table.pack(...) + local strings = {} + for i=1, args.n do + local item = args[i] + local str = pptostring(item, customoptions) + if i == 1 then + capturedoutput = capturedoutput .. str + else + capturedoutput = capturedoutput .. "\t" .. str + end + end +end +)"; +}; + +TEST_SUITE_BEGIN("ReplPrettyPrint"); + +TEST_CASE_FIXTURE(ReplFixture, "AdditionStatement") +{ + runCode(L, "return 30 + 12"); + CHECK(getCapturedOutput() == "42"); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableLiteral") +{ + runCode(L, "return {1, 2, 3, 4}"); + CHECK(getCapturedOutput() == "{1, 2, 3, 4}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "StringLiteral") +{ + runCode(L, "return 'str'"); + CHECK(getCapturedOutput() == "\"str\""); +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithStringLiterals") +{ + runCode(L, "return {1, 'two', 3, 'four'}"); + CHECK(getCapturedOutput() == "{1, \"two\", 3, \"four\"}"); +} + +TEST_CASE_FIXTURE(ReplFixture, "MultipleArguments") +{ + runCode(L, "return 3, 'three'"); + CHECK(getCapturedOutput() == "3\t\"three\""); +} + +TEST_SUITE_END(); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 445ee5329..bbb262910 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -435,8 +435,6 @@ TEST_CASE_FIXTURE(Fixture, "toString_the_boundTo_table_type_contained_within_a_T TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type F = ((() -> number)?) -> F? local function f(p) return f end @@ -450,8 +448,6 @@ TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_union" TEST_CASE_FIXTURE(Fixture, "no_parentheses_around_cyclic_function_type_in_intersection") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f() return f end local a: ((number) -> ()) & typeof(f) diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 822bd727e..76ab23b3a 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -11,8 +11,6 @@ TEST_SUITE_BEGIN("TypeAliases"); TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type F = () -> F? local function f() @@ -194,8 +192,6 @@ TEST_CASE_FIXTURE(Fixture, "corecursive_types_generic") TEST_CASE_FIXTURE(Fixture, "corecursive_function_types") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( type A = () -> (number, B) type B = () -> (string, A) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 114679e34..a7f275515 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) - TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -656,11 +654,7 @@ local d: D = c LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauExtendedFunctionMismatchError) - CHECK_EQ( - toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); + CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type parameters)"); } TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_generic_pack") @@ -675,11 +669,8 @@ local d: D = c LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauExtendedFunctionMismatchError) - CHECK_EQ(toString(result.errors[0]), - R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); - else - CHECK_EQ(toString(result.errors[0]), R"(Type '() -> ()' could not be converted into '() -> ()')"); + CHECK_EQ(toString(result.errors[0]), + R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e6d3d4d47..47c13be9a 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -271,6 +271,32 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") CHECK_EQ(toString(requireTypeAtPosition({5, 36})), "boolean?"); // a ~= b } +// Also belongs in TypeInfer.refinements.test.cpp. +// Just needs to fully support equality refinement. Which is annoying without type states. +TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") +{ + ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + + CheckResult result = check(R"( + type T = {x: string, y: number} | {x: nil, y: nil} + + local function f(t: T) + if t.x ~= nil then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{| x: string, y: number |}", toString(requireTypeAtPosition({5, 28}))); + + // Should be {| x: nil, y: nil |} + CHECK_EQ("{| x: nil, y: nil |} | {| x: string, y: number |}", toString(requireTypeAtPosition({7, 28}))); +} + TEST_CASE_FIXTURE(Fixture, "bail_early_if_unification_is_too_complicated" * doctest::timeout(0.5)) { ScopedFastInt sffi{"LuauTarjanChildLimit", 1}; @@ -590,8 +616,6 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") { - ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - // Mutability in type function application right now can create strange recursive types // TODO: instantiation right now is problematic, in this example should either leave the Table type alone // or it should rename the type to 'Self' so that the result will be 'Self
' diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index d76b920bb..f346ddfdf 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -6,11 +6,77 @@ #include "doctest.h" +LUAU_FASTFLAG(LuauDiscriminableUnions) LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; +namespace +{ +std::optional> magicFunctionInstanceIsA( + TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) +{ + if (expr.args.size != 1) + return std::nullopt; + + auto index = expr.func->as(); + auto str = expr.args.data[0]->as(); + if (!index || !str) + return std::nullopt; + + std::optional lvalue = tryGetLValue(*index->expr); + std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); + if (!lvalue || !tfun) + return std::nullopt; + + unfreeze(typeChecker.globalTypes); + TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); + freeze(typeChecker.globalTypes); + return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; +} + +struct RefinementClassFixture : Fixture +{ + RefinementClassFixture() + { + TypeArena& arena = typeChecker.globalTypes; + + unfreeze(arena); + TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); + getMutable(vec3)->props = { + {"X", Property{typeChecker.numberType}}, + {"Y", Property{typeChecker.numberType}}, + {"Z", Property{typeChecker.numberType}}, + }; + + TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); + + TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); + TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); + TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); + getMutable(isA)->magicFunction = magicFunctionInstanceIsA; + + getMutable(inst)->props = { + {"Name", Property{typeChecker.stringType}}, + {"IsA", Property{isA}}, + }; + + TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); + TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); + getMutable(part)->props = { + {"Position", Property{vec3}}, + }; + + typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; + typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; + typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; + typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; + freeze(typeChecker.globalTypes); + } +}; +} // namespace + TEST_SUITE_BEGIN("RefinementTest"); TEST_CASE_FIXTURE(Fixture, "is_truthy_constraint") @@ -196,8 +262,18 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); + if (FFlag::LuauDiscriminableUnions) + { + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({8, 44}))); + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({9, 38}))); + } + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' has no overlap with 'string'", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "call_a_more_specific_function_using_typeguard") @@ -237,7 +313,6 @@ TEST_CASE_FIXTURE(Fixture, "impossible_type_narrow_is_not_an_error") TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") { - CheckResult result = check(R"( local t: {x: number?} = {x = 1} @@ -254,7 +329,6 @@ TEST_CASE_FIXTURE(Fixture, "truthy_constraint_on_properties") TEST_CASE_FIXTURE(Fixture, "index_on_a_refined_property") { - CheckResult result = check(R"( local t: {x: {y: string}?} = {x = {y = "hello!"}} @@ -360,7 +434,10 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauSingletonTypes", true}, + }; CheckResult result = check(R"( local function f(a: (string | number)?) @@ -374,16 +451,8 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "(number | string)?"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 28})), "string"); // a == "hello" - CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" - } + CHECK_EQ(toString(requireTypeAtPosition({3, 28})), R"("hello")"); // a == "hello" + CHECK_EQ(toString(requireTypeAtPosition({5, 28})), "(number | string)?"); // a ~= "hello" } TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") @@ -416,7 +485,8 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; + ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; CheckResult result = check(R"( local function f(a, b: string?) @@ -428,16 +498,8 @@ TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") LUAU_REQUIRE_NO_ERRORS(result); - if (FFlag::LuauWeakEqConstraint) - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b - } - else - { - CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "string?"); // a == b - CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b - } + CHECK_EQ(toString(requireTypeAtPosition({3, 33})), "a"); // a == b + CHECK_EQ(toString(requireTypeAtPosition({3, 36})), "string?"); // a == b } TEST_CASE_FIXTURE(Fixture, "unknown_lvalue_is_not_synonymous_with_other_on_not_equal") @@ -527,9 +589,17 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); + if (FFlag::LuauDiscriminableUnions) + { + LUAU_REQUIRE_NO_ERRORS(result); + } + else + { + // This is kinda weird to see, but this actually only happens in Luau without Roblox type bindings because we don't have a Vector3 type. + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown type 'Vector3'", toString(result.errors[0])); + } + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({3, 28}))); } @@ -614,214 +684,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_narrows_for_functions") CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); // type(x) ~= "function" } -namespace -{ -std::optional> magicFunctionInstanceIsA( - TypeChecker& typeChecker, const ScopePtr& scope, const AstExprCall& expr, ExprResult exprResult) -{ - if (expr.args.size != 1) - return std::nullopt; - - auto index = expr.func->as(); - auto str = expr.args.data[0]->as(); - if (!index || !str) - return std::nullopt; - - std::optional lvalue = tryGetLValue(*index->expr); - std::optional tfun = scope->lookupType(std::string(str->value.data, str->value.size)); - if (!lvalue || !tfun) - return std::nullopt; - - unfreeze(typeChecker.globalTypes); - TypePackId booleanPack = typeChecker.globalTypes.addTypePack({typeChecker.booleanType}); - freeze(typeChecker.globalTypes); - return ExprResult{booleanPack, {IsAPredicate{std::move(*lvalue), expr.location, tfun->type}}}; -} - -struct RefinementClassFixture : Fixture -{ - RefinementClassFixture() - { - TypeArena& arena = typeChecker.globalTypes; - - unfreeze(arena); - TypeId vec3 = arena.addType(ClassTypeVar{"Vector3", {}, std::nullopt, std::nullopt, {}, nullptr}); - getMutable(vec3)->props = { - {"X", Property{typeChecker.numberType}}, - {"Y", Property{typeChecker.numberType}}, - {"Z", Property{typeChecker.numberType}}, - }; - - TypeId inst = arena.addType(ClassTypeVar{"Instance", {}, std::nullopt, std::nullopt, {}, nullptr}); - - TypePackId isAParams = arena.addTypePack({inst, typeChecker.stringType}); - TypePackId isARets = arena.addTypePack({typeChecker.booleanType}); - TypeId isA = arena.addType(FunctionTypeVar{isAParams, isARets}); - getMutable(isA)->magicFunction = magicFunctionInstanceIsA; - - getMutable(inst)->props = { - {"Name", Property{typeChecker.stringType}}, - {"IsA", Property{isA}}, - }; - - TypeId folder = typeChecker.globalTypes.addType(ClassTypeVar{"Folder", {}, inst, std::nullopt, {}, nullptr}); - TypeId part = typeChecker.globalTypes.addType(ClassTypeVar{"Part", {}, inst, std::nullopt, {}, nullptr}); - getMutable(part)->props = { - {"Position", Property{vec3}}, - }; - - typeChecker.globalScope->exportedTypeBindings["Vector3"] = TypeFun{{}, vec3}; - typeChecker.globalScope->exportedTypeBindings["Instance"] = TypeFun{{}, inst}; - typeChecker.globalScope->exportedTypeBindings["Folder"] = TypeFun{{}, folder}; - typeChecker.globalScope->exportedTypeBindings["Part"] = TypeFun{{}, part}; - freeze(typeChecker.globalTypes); - } -}; -} // namespace - -TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") -{ - CheckResult result = check(R"( - local function f(vec) - local X, Y, Z = vec.X, vec.Y, vec.Z - - if type(vec) == "vector" then - local foo = vec - elseif typeof(vec) == "Instance" then - local foo = vec - else - local foo = vec - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" - - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - else - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); - - CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" - else - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") -{ - CheckResult result = check(R"( - local function f(x: Instance | Vector3) - if typeof(x) == "Vector3" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") -{ - CheckResult result = check(R"( - local function f(x: string | number | Instance | Vector3) - if type(x) == "userdata" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | string) - if typeof(x) == "Instance" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | Instance | string | Vector3 | any) - if typeof(x) == "Instance" then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") -{ - CheckResult result = check(R"( - --!nonstrict - - local function f(x) - if typeof(x) == "Instance" and x:IsA("Folder") then - local foo = x - elseif typeof(x) == "table" then - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); - CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); -} - -TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") -{ - CheckResult result = check(R"( - local function f(x: Part | Folder | string) - if typeof(x) ~= "Instance" or not x:IsA("Part") then - local foo = x - else - local foo = x - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28}))); - CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); -} - TEST_CASE_FIXTURE(Fixture, "type_guard_can_filter_for_intersection_of_tables") { CheckResult result = check(R"( @@ -1145,4 +1007,259 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type T = {tag: "missing", x: nil} | {tag: "exists", x: string} + + local function f(t: T) + if t.x then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"({| tag: "exists", x: string |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "missing", x: nil |})", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(Fixture, "discriminate_tag") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type Cat = {tag: "Cat", name: string, catfood: string} + type Dog = {tag: "Dog", name: string, dogfood: string} + type Animal = Cat | Dog + + local function f(animal: Animal) + if animal.tag == "Cat" then + local cat: Cat = animal + elseif animal.tag == "Dog" then + local dog: Dog = animal + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Cat", toString(requireTypeAtPosition({7, 33}))); + CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); +} + +TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +{ + ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + + CheckResult result = check(R"( + type T = { [string]: { prop: number }? } + local t: T = {} + + if t["hello"] then + local foo = t["hello"].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") +{ + CheckResult result = check(R"( + local function len(a: {any}) + return a and #a or nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") +{ + ScopedFastFlag sff[] = { + {"LuauDiscriminableUnions", true}, + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + type T = {tag: "Part", x: Part} | {tag: "Folder", x: Folder} + + local function f(t: T) + if t.x:IsA("Part") then + local foo = t + else + local bar = t + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"({| tag: "Part", x: Part |})", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ(R"({| tag: "Folder", x: Folder |})", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") +{ + CheckResult result = check(R"( + local function f(vec) + local X, Y, Z = vec.X, vec.Y, vec.Z + + if type(vec) == "vector" then + local foo = vec + elseif typeof(vec) == "Instance" then + local foo = vec + else + local foo = vec + end + end + )"); + + if (FFlag::LuauDiscriminableUnions) + LUAU_REQUIRE_NO_ERRORS(result); + else + { + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); + else + CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + } + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" + + CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" + + if (FFlag::LuauQuantifyInPlace2) + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + else + CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") +{ + CheckResult result = check(R"( + local function f(x: Instance | Vector3) + if typeof(x) == "Vector3" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Instance", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "type_narrow_for_all_the_userdata") +{ + CheckResult result = check(R"( + local function f(x: string | number | Instance | Vector3) + if type(x) == "userdata" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Instance | Vector3", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("number | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "eliminate_subclasses_of_instance") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "narrow_this_large_union") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | Instance | string | Vector3 | any) + if typeof(x) == "Instance" then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | Instance | Part", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Vector3 | any | string", toString(requireTypeAtPosition({5, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_as_any_if_x_is_instance_elseif_x_is_table") +{ + CheckResult result = check(R"( + --!nonstrict + + local function f(x) + if typeof(x) == "Instance" and x:IsA("Folder") then + local foo = x + elseif typeof(x) == "table" then + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder", toString(requireTypeAtPosition({5, 28}))); + CHECK_EQ("any", toString(requireTypeAtPosition({7, 28}))); +} + +TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") +{ + CheckResult result = check(R"( + local function f(x: Part | Folder | string) + if typeof(x) ~= "Instance" or not x:IsA("Part") then + local foo = x + else + local foo = x + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("Folder | string", toString(requireTypeAtPosition({3, 28}))); + CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 94cfb6437..df365fda4 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -379,9 +379,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, }; CheckResult result = check(R"( @@ -404,9 +402,7 @@ TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, }; CheckResult result = check(R"( @@ -429,9 +425,7 @@ TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") ScopedFastFlag sffs[] = { {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, - {"LuauUnionHeuristic", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauExtendedUnionMismatchError", true}, {"LuauIfElseExpectedType2", true}, {"LuauIfElseBranchTypeUnion", true}, }; diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 644efed75..483109213 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -12,8 +12,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) - TEST_SUITE_BEGIN("TableTests"); TEST_CASE_FIXTURE(Fixture, "basic") @@ -2075,22 +2073,11 @@ caused by: caused by: Property 'y' is not compatible. Type 'string' could not be converted into 'number')"); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' + CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' caused by: Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); - } - else - { - CHECK_EQ(toString(result.errors[1]), R"(Type 'b2' could not be converted into 'a2' -caused by: - Type '{ __call: (a, b) -> () }' could not be converted into '{ __call: (a) -> () }' -caused by: - Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()')"); - } } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") @@ -2166,7 +2153,6 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "recursive_metatable_type_call") { ScopedFastFlag sff[]{ - {"LuauFixRecursiveMetatableCall", true}, {"LuauUnsealedTableLiteral", true}, }; diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 7ee5253c7..c9b30e1a2 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAG(LuauFixLocationSpanTableIndexExpr) LUAU_FASTFLAG(LuauEqConstraint) -LUAU_FASTFLAG(LuauExtendedFunctionMismatchError) using namespace Luau; @@ -959,8 +958,6 @@ TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f() return f @@ -973,8 +970,6 @@ TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( function f(g) return f(f) @@ -1699,8 +1694,6 @@ TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict local s @@ -1711,8 +1704,6 @@ TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type") { - ScopedFastFlag sff{"LuauOccursCheckOkWithRecursiveFunctions", true}; - CheckResult result = check(R"( --!strict function u(t, w) @@ -3326,11 +3317,12 @@ TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable") +TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") { CheckResult result = check(R"( - local x - print((x == true and (x .. "y")) .. 1) + local function f(x) + return x .. "y" + end )"); LUAU_REQUIRE_ERROR_COUNT(1, result); @@ -3340,13 +3332,14 @@ TEST_CASE_FIXTURE(Fixture, "relation_op_on_any_lhs_where_rhs_maybe_has_metatable TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") { CheckResult result = check(R"( - local x - print("foo" .. x) + local function f(x) + return "foo" .. x + end )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("(string) -> string", toString(requireType("f"))); } TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") @@ -4374,8 +4367,6 @@ TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") TEST_CASE_FIXTURE(Fixture, "record_matching_overload") { - ScopedFastFlag sffs("LuauStoreMatchingOverloadFnType", true); - CheckResult result = check(R"( type Overload = ((string) -> string) & ((number) -> number) local abc: Overload @@ -4475,17 +4466,10 @@ f(function(a, b, c, ...) return a + b end) LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' caused by: Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - } - else - { - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number')", toString(result.errors[0])); - } + toString(result.errors[0])); // Infer from variadic packs into elements result = check(R"( @@ -4618,17 +4602,9 @@ local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not i )"); LUAU_REQUIRE_ERRORS(result); - if (FFlag::LuauExtendedFunctionMismatchError) - { - CHECK_EQ( - "Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); - } - else - { - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'", toString(result.errors[0])); - } + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") @@ -4741,8 +4717,6 @@ TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") { - ScopedFastFlag luauCloneCorrectlyBeforeMutatingTableType{"LuauCloneCorrectlyBeforeMutatingTableType", true}; - CheckResult result = check(R"( type A = { x: number } local a: A = { x = 1 } @@ -4965,8 +4939,6 @@ TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_ TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number) -> string @@ -4983,8 +4955,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number, string) -> string @@ -5001,8 +4971,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> (number) type B = (number, number) -> (number, boolean) @@ -5019,8 +4987,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> string type B = (number, number) -> number @@ -5037,8 +5003,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") { - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; - CheckResult result = check(R"( type A = (number, number) -> (number, string) type B = (number, number) -> (number, boolean) @@ -5069,8 +5033,6 @@ TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") { - ScopedFastFlag luauUnifyFunctionCheckResult{"LuauUpdateFunctionNameBinding", true}; - CheckResult result = check(R"( local t = {} diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index d4878d149..079870f57 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -931,7 +931,6 @@ type R = { m: F } TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") { ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; - ScopedFastFlag luauExtendedFunctionMismatchError{"LuauExtendedFunctionMismatchError", true}; CheckResult result = check(R"( local a: () -> (number, ...string) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index b54ba9962..759794e63 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -464,8 +464,6 @@ local a: XYZ = { w = 4 } TEST_CASE_FIXTURE(Fixture, "error_detailed_optional") { - ScopedFastFlag luauExtendedUnionMismatchError{"LuauExtendedUnionMismatchError", true}; - CheckResult result = check(R"( type X = { x: number } diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 2e0d149ec..329e7b1f6 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -268,8 +268,6 @@ TEST_CASE_FIXTURE(Fixture, "substitution_skip_failure") TEST_CASE("tagging_tables") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar ttv{TableTypeVar{}}; CHECK(!Luau::hasTag(&ttv, "foo")); Luau::attachTag(&ttv, "foo"); @@ -278,8 +276,6 @@ TEST_CASE("tagging_tables") TEST_CASE("tagging_classes") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; CHECK(!Luau::hasTag(&base, "foo")); Luau::attachTag(&base, "foo"); @@ -288,8 +284,6 @@ TEST_CASE("tagging_classes") TEST_CASE("tagging_subclasses") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypeVar base{ClassTypeVar{"Base", {}, std::nullopt, std::nullopt, {}, nullptr}}; TypeVar derived{ClassTypeVar{"Derived", {}, &base, std::nullopt, {}, nullptr}}; @@ -307,8 +301,6 @@ TEST_CASE("tagging_subclasses") TEST_CASE("tagging_functions") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - TypePackVar empty{TypePack{}}; TypeVar ftv{FunctionTypeVar{&empty, &empty}}; CHECK(!Luau::hasTag(&ftv, "foo")); @@ -318,8 +310,6 @@ TEST_CASE("tagging_functions") TEST_CASE("tagging_props") { - ScopedFastFlag sff{"LuauRefactorTagging", true}; - Property prop{}; CHECK(!Luau::hasTag(prop, "foo")); Luau::attachTag(prop, "foo"); @@ -370,4 +360,66 @@ local b: (T, T, T) -> T CHECK_EQ(count, 1); } +TEST_CASE("isString_on_string_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + CHECK(isString(&helloString)); +} + +TEST_CASE("isString_on_unions_of_various_string_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString}}}; + + CHECK(isString(&union_)); +} + +TEST_CASE("proof_that_isString_uses_all_of") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; + TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; + TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}}; + TypeVar union_{UnionTypeVar{{&helloString, &byeString, &booleanType}}}; + + CHECK(!isString(&union_)); +} + +TEST_CASE("isBoolean_on_boolean_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + CHECK(isBoolean(&trueBool)); +} + +TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}}; + + CHECK(isBoolean(&union_)); +} + +TEST_CASE("proof_that_isBoolean_uses_all_of") +{ + ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; + + TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; + TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; + TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}}; + TypeVar union_{UnionTypeVar{{&trueBool, &falseBool, &stringType}}}; + + CHECK(!isBoolean(&union_)); +} + TEST_SUITE_END(); From 78039f45355c10064bfb635ca6b316b2e9303294 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 27 Jan 2022 13:52:56 -0800 Subject: [PATCH 20/32] Thanks gcc, we know you can't compile code. --- Analysis/src/TypeInfer.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 23fcc2d5b..d6b3b5b37 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -5929,7 +5929,9 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa if (!sense || canUnify(eqP.type, option, eqP.location).empty()) return sense ? eqP.type : option; - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; } return option; From f6b4cc9442f57db2ef9c186d7fea008dde8b2f68 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 3 Feb 2022 15:09:37 -0800 Subject: [PATCH 21/32] Sync to upstream/release/513 --- Analysis/include/Luau/Error.h | 12 +- Analysis/include/Luau/LValue.h | 14 +- Analysis/include/Luau/Substitution.h | 6 + Analysis/include/Luau/TxnLog.h | 8 +- Analysis/include/Luau/TypeInfer.h | 9 +- Analysis/include/Luau/TypedAllocator.h | 22 +- Analysis/include/Luau/Unifier.h | 4 + Analysis/src/BuiltinDefinitions.cpp | 39 +- Analysis/src/LValue.cpp | 52 +- Analysis/src/Linter.cpp | 4 +- Analysis/src/Substitution.cpp | 65 +- Analysis/src/TxnLog.cpp | 71 +- Analysis/src/TypeInfer.cpp | 242 +-- Analysis/src/TypeVar.cpp | 8 +- Analysis/src/TypedAllocator.cpp | 1 - Analysis/src/Unifier.cpp | 523 +++-- Ast/src/Parser.cpp | 3 +- CLI/Coverage.cpp | 2 +- CLI/FileUtils.cpp | 2 +- CLI/Repl.cpp | 55 +- CLI/ReplEntry.cpp | 3 +- CMakeLists.txt | 19 +- Compiler/src/Builtins.cpp | 4 +- Compiler/src/Compiler.cpp | 9 +- Compiler/src/TableShape.cpp | 8 - Makefile | 24 +- Sources.cmake | 5 + VM/include/luaconf.h | 4 +- VM/src/lcorolib.cpp | 2 +- VM/src/ldebug.cpp | 3 +- VM/src/lfunc.cpp | 6 +- VM/src/lgc.cpp | 2 + VM/src/lmem.cpp | 53 +- VM/src/lstring.cpp | 2 +- VM/src/lvmexecute.cpp | 11 +- VM/src/lvmload.cpp | 8 +- extern/isocline/.gitignore | 16 + extern/isocline/LICENSE | 21 + extern/isocline/include/isocline.h | 627 ++++++ extern/isocline/readme.md | 460 ++++ extern/isocline/src/attr.c | 294 +++ extern/isocline/src/attr.h | 70 + extern/isocline/src/bbcode.c | 842 +++++++ extern/isocline/src/bbcode.h | 37 + extern/isocline/src/bbcode_colors.c | 194 ++ extern/isocline/src/common.c | 347 +++ extern/isocline/src/common.h | 187 ++ extern/isocline/src/completers.c | 675 ++++++ extern/isocline/src/completions.c | 326 +++ extern/isocline/src/completions.h | 52 + extern/isocline/src/editline.c | 1142 ++++++++++ extern/isocline/src/editline_completion.c | 277 +++ extern/isocline/src/editline_help.c | 140 ++ extern/isocline/src/editline_history.c | 260 +++ extern/isocline/src/env.h | 60 + extern/isocline/src/highlight.c | 259 +++ extern/isocline/src/highlight.h | 24 + extern/isocline/src/history.c | 269 +++ extern/isocline/src/history.h | 38 + extern/isocline/src/isocline.c | 589 +++++ extern/isocline/src/stringbuf.c | 1038 +++++++++ extern/isocline/src/stringbuf.h | 121 ++ extern/isocline/src/term.c | 1124 ++++++++++ extern/isocline/src/term.h | 85 + extern/isocline/src/term_color.c | 371 ++++ extern/isocline/src/tty.c | 889 ++++++++ extern/isocline/src/tty.h | 160 ++ extern/isocline/src/tty_esc.c | 401 ++++ extern/isocline/src/undo.c | 67 + extern/isocline/src/undo.h | 24 + extern/isocline/src/wcwidth.c | 292 +++ extern/linenoise.hpp | 2415 --------------------- tests/Autocomplete.test.cpp | 116 +- tests/Compiler.test.cpp | 20 +- tests/Conformance.test.cpp | 1 - tests/LValue.test.cpp | 62 +- tests/Linter.test.cpp | 25 +- tests/Parser.test.cpp | 7 +- tests/TypeInfer.annotations.test.cpp | 3 - tests/TypeInfer.builtins.test.cpp | 51 + tests/TypeInfer.provisional.test.cpp | 13 - tests/TypeInfer.refinements.test.cpp | 61 +- tests/TypeInfer.tables.test.cpp | 8 +- tests/TypeInfer.test.cpp | 86 +- tests/TypeInfer.tryUnify.test.cpp | 13 + tests/conformance/basic.lua | 4 + tests/conformance/vararg.lua | 50 +- 87 files changed, 12810 insertions(+), 3208 deletions(-) create mode 100644 extern/isocline/.gitignore create mode 100644 extern/isocline/LICENSE create mode 100644 extern/isocline/include/isocline.h create mode 100644 extern/isocline/readme.md create mode 100644 extern/isocline/src/attr.c create mode 100644 extern/isocline/src/attr.h create mode 100644 extern/isocline/src/bbcode.c create mode 100644 extern/isocline/src/bbcode.h create mode 100644 extern/isocline/src/bbcode_colors.c create mode 100644 extern/isocline/src/common.c create mode 100644 extern/isocline/src/common.h create mode 100644 extern/isocline/src/completers.c create mode 100644 extern/isocline/src/completions.c create mode 100644 extern/isocline/src/completions.h create mode 100644 extern/isocline/src/editline.c create mode 100644 extern/isocline/src/editline_completion.c create mode 100644 extern/isocline/src/editline_help.c create mode 100644 extern/isocline/src/editline_history.c create mode 100644 extern/isocline/src/env.h create mode 100644 extern/isocline/src/highlight.c create mode 100644 extern/isocline/src/highlight.h create mode 100644 extern/isocline/src/history.c create mode 100644 extern/isocline/src/history.h create mode 100644 extern/isocline/src/isocline.c create mode 100644 extern/isocline/src/stringbuf.c create mode 100644 extern/isocline/src/stringbuf.h create mode 100644 extern/isocline/src/term.c create mode 100644 extern/isocline/src/term.h create mode 100644 extern/isocline/src/term_color.c create mode 100644 extern/isocline/src/tty.c create mode 100644 extern/isocline/src/tty.h create mode 100644 extern/isocline/src/tty_esc.c create mode 100644 extern/isocline/src/undo.c create mode 100644 extern/isocline/src/undo.h create mode 100644 extern/isocline/src/wcwidth.c delete mode 100644 extern/linenoise.hpp diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index aff3c4d9e..a71e02246 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -285,12 +285,12 @@ struct TypesAreUnrelated bool operator==(const TypesAreUnrelated& rhs) const; }; -using TypeErrorData = Variant; +using TypeErrorData = + Variant; struct TypeError { diff --git a/Analysis/include/Luau/LValue.h b/Analysis/include/Luau/LValue.h index 8fd96f05a..3d510d5f9 100644 --- a/Analysis/include/Luau/LValue.h +++ b/Analysis/include/Luau/LValue.h @@ -4,7 +4,6 @@ #include "Luau/Variant.h" #include "Luau/Symbol.h" -#include // TODO: Kill with LuauLValueAsKey. #include #include @@ -38,24 +37,13 @@ std::optional tryGetLValue(const class AstExpr& expr); // Utility function: breaks down an LValue to get at the Symbol, and reverses the vector of keys. std::pair> getFullName(const LValue& lvalue); -// Kill with LuauLValueAsKey. -std::string toString(const LValue& lvalue); - template const T* get(const LValue& lvalue) { return get_if(&lvalue); } -using NEW_RefinementMap = std::unordered_map; -using DEPRECATED_RefinementMap = std::map; - -// Transient. Kill with LuauLValueAsKey. -struct RefinementMap -{ - NEW_RefinementMap NEW_refinements; - DEPRECATED_RefinementMap DEPRECATED_refinements; -}; +using RefinementMap = std::unordered_map; void merge(RefinementMap& l, const RefinementMap& r, std::function f); void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty); diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 80a14e8fb..4f3307cdf 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -55,6 +55,8 @@ namespace Luau { +struct TxnLog; + enum class TarjanResult { TooManyChildren, @@ -89,6 +91,10 @@ struct Tarjan int childCount = 0; + // This should never be null; ensure you initialize it before calling + // substitution methods. + const TxnLog* log; + std::vector edgesTy; std::vector edgesTp; std::vector worklist; diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index dc45bebf4..02b873748 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -72,6 +72,9 @@ struct PendingType } }; +std::string toString(PendingType* pending); +std::string dump(PendingType* pending); + // Pending state for a TypePackVar. Generated by a TxnLog and committed via // TxnLog::commit. struct PendingTypePack @@ -85,6 +88,9 @@ struct PendingTypePack } }; +std::string toString(PendingTypePack* pending); +std::string dump(PendingTypePack* pending); + template T* getMutable(PendingType* pending) { @@ -237,7 +243,7 @@ struct TxnLog // Follows a type, accounting for pending type states. The returned type may have // pending state; you should use `pending` or `get` to find out. - TypeId follow(TypeId ty); + TypeId follow(TypeId ty) const; // Follows a type pack, accounting for pending type states. The returned type pack // may have pending state; you should use `pending` or `get` to find out. diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index b843509dc..90dc9f426 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -262,7 +262,7 @@ struct TypeChecker * {method: ({method: () -> a}) -> a} * */ - TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location); + TypeId instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log = TxnLog::empty()); // Replace any free types or type packs by `any`. // This is used when exporting types from modules, to make sure free types don't leak. @@ -308,9 +308,15 @@ struct TypeChecker TypeId singletonType(bool value); TypeId singletonType(std::string value); + TypeIdPredicate mkTruthyPredicate(bool sense); + // Returns nullopt if the predicate filters down the TypeId to 0 options. std::optional filterMap(TypeId type, TypeIdPredicate predicate); +public: + std::optional pickTypesFromSense(TypeId type, bool sense); + +private: TypeId unionOfTypes(TypeId a, TypeId b, const Location& location, bool unifyFreeTypes = true); // ex @@ -349,7 +355,6 @@ struct TypeChecker void refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate); std::optional resolveLValue(const ScopePtr& scope, const LValue& lvalue); - std::optional DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue); std::optional resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue); void resolve(const PredicateVec& predicates, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr = false); diff --git a/Analysis/include/Luau/TypedAllocator.h b/Analysis/include/Luau/TypedAllocator.h index 64227e7c1..c1c04d10a 100644 --- a/Analysis/include/Luau/TypedAllocator.h +++ b/Analysis/include/Luau/TypedAllocator.h @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauTypedAllocatorZeroStart) - namespace Luau { @@ -22,10 +20,7 @@ class TypedAllocator public: TypedAllocator() { - if (FFlag::LuauTypedAllocatorZeroStart) - currentBlockSize = kBlockSize; - else - appendBlock(); + currentBlockSize = kBlockSize; } ~TypedAllocator() @@ -64,18 +59,12 @@ class TypedAllocator bool empty() const { - if (FFlag::LuauTypedAllocatorZeroStart) - return stuff.empty(); - else - return stuff.size() == 1 && currentBlockSize == 0; + return stuff.empty(); } size_t size() const { - if (FFlag::LuauTypedAllocatorZeroStart) - return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; - else - return kBlockSize * (stuff.size() - 1) + currentBlockSize; + return stuff.empty() ? 0 : kBlockSize * (stuff.size() - 1) + currentBlockSize; } void clear() @@ -84,10 +73,7 @@ class TypedAllocator unfreeze(); free(); - if (FFlag::LuauTypedAllocatorZeroStart) - currentBlockSize = kBlockSize; - else - appendBlock(); + currentBlockSize = kBlockSize; } void freeze() diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 1b1671c0a..9db4e22b0 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -51,6 +51,10 @@ struct Unifier private: void tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall = false, bool isIntersection = false); + void tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy); + void tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall); + void tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv); + void tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall); void tryUnifyPrimitives(TypeId subTy, TypeId superTy); void tryUnifySingletons(TypeId subTy, TypeId superTy); void tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall = false); diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index d527414a1..d72422a53 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -8,6 +8,8 @@ #include +LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) + /** FIXME: Many of these type definitions are not quite completely accurate. * * Some of them require richer generics than we have. For instance, we do not yet have a way to talk @@ -391,12 +393,41 @@ static std::optional> magicFunctionAssert( { auto [paramPack, predicates] = exprResult; - if (expr.args.size < 1) - return ExprResult{paramPack}; + if (FFlag::LuauAssertStripsFalsyTypes) + { + TypeArena& arena = typechecker.currentModule->internalTypes; + + auto [head, tail] = flatten(paramPack); + if (head.empty() && tail) + { + std::optional fst = first(*tail); + if (!fst) + return ExprResult{paramPack}; + head.push_back(*fst); + } + + typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); + + if (head.size() > 0) + { + std::optional newhead = typechecker.pickTypesFromSense(head[0], true); + if (!newhead) + head = {typechecker.nilType}; + else + head[0] = *newhead; + } + + return ExprResult{arena.addTypePack(TypePack{std::move(head), tail})}; + } + else + { + if (expr.args.size < 1) + return ExprResult{paramPack}; - typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); + typechecker.reportErrors(typechecker.resolve(predicates, scope, true)); - return ExprResult{paramPack}; + return ExprResult{paramPack}; + } } static std::optional> magicFunctionPack( diff --git a/Analysis/src/LValue.cpp b/Analysis/src/LValue.cpp index da6804c6b..c9466a40e 100644 --- a/Analysis/src/LValue.cpp +++ b/Analysis/src/LValue.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauLValueAsKey) - namespace Luau { @@ -94,17 +92,7 @@ std::pair> getFullName(const LValue& lvalue) return {*symbol, std::vector(keys.rbegin(), keys.rend())}; } -// Kill with LuauLValueAsKey. -std::string toString(const LValue& lvalue) -{ - auto [symbol, keys] = getFullName(lvalue); - std::string s = toString(symbol); - for (std::string key : keys) - s += "." + key; - return s; -} - -static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::function f) +void merge(RefinementMap& l, const RefinementMap& r, std::function f) { for (const auto& [k, a] : r) { @@ -115,45 +103,9 @@ static void merge(NEW_RefinementMap& l, const NEW_RefinementMap& r, std::functio } } -static void merge(DEPRECATED_RefinementMap& l, const DEPRECATED_RefinementMap& r, std::function f) -{ - auto itL = l.begin(); - auto itR = r.begin(); - while (itL != l.end() && itR != r.end()) - { - const auto& [k, a] = *itR; - if (itL->first == k) - { - l[k] = f(itL->second, a); - ++itL; - ++itR; - } - else if (itL->first < k) - ++itL; - else - { - l[k] = a; - ++itR; - } - } - - l.insert(itR, r.end()); -} - -void merge(RefinementMap& l, const RefinementMap& r, std::function f) -{ - if (FFlag::LuauLValueAsKey) - return merge(l.NEW_refinements, r.NEW_refinements, f); - else - return merge(l.DEPRECATED_refinements, r.DEPRECATED_refinements, f); -} - void addRefinement(RefinementMap& refis, const LValue& lvalue, TypeId ty) { - if (FFlag::LuauLValueAsKey) - refis.NEW_refinements[lvalue] = ty; - else - refis.DEPRECATED_refinements[toString(lvalue)] = ty; + refis[lvalue] = ty; } } // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 905b70bff..57a33e931 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,8 +12,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauLintTableCreateTable, false) - namespace Luau { @@ -2155,7 +2153,7 @@ class LintTableOperations : AstVisitor "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); } - if (FFlag::LuauLintTableCreateTable && func->index == "create" && node->args.size == 2) + if (func->index == "create" && node->args.size == 2) { // table.create(n, {...}) if (args[1]->is()) diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index 3d004bee3..bacbca762 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -2,6 +2,7 @@ #include "Luau/Substitution.h" #include "Luau/Common.h" +#include "Luau/TxnLog.h" #include #include @@ -13,17 +14,17 @@ namespace Luau void Tarjan::visitChildren(TypeId ty, int index) { - ty = follow(ty); + ty = log->follow(ty); if (ignoreChildren(ty)) return; - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) { visitChild(ftv->argTypes); visitChild(ftv->retType); } - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableTypeVar* ttv = log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); for (const auto& [name, prop] : ttv->props) @@ -40,17 +41,17 @@ void Tarjan::visitChildren(TypeId ty, int index) for (TypePackId itp : ttv->instantiatedTypePackParams) visitChild(itp); } - else if (const MetatableTypeVar* mtv = get(ty)) + else if (const MetatableTypeVar* mtv = log->getMutable(ty)) { visitChild(mtv->table); visitChild(mtv->metatable); } - else if (const UnionTypeVar* utv = get(ty)) + else if (const UnionTypeVar* utv = log->getMutable(ty)) { for (TypeId opt : utv->options) visitChild(opt); } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionTypeVar* itv = log->getMutable(ty)) { for (TypeId part : itv->parts) visitChild(part); @@ -59,19 +60,19 @@ void Tarjan::visitChildren(TypeId ty, int index) void Tarjan::visitChildren(TypePackId tp, int index) { - tp = follow(tp); + tp = log->follow(tp); if (ignoreChildren(tp)) return; - if (const TypePack* tpp = get(tp)) + if (const TypePack* tpp = log->getMutable(tp)) { for (TypeId tv : tpp->head) visitChild(tv); if (tpp->tail) visitChild(*tpp->tail); } - else if (const VariadicTypePack* vtp = get(tp)) + else if (const VariadicTypePack* vtp = log->getMutable(tp)) { visitChild(vtp->ty); } @@ -79,7 +80,7 @@ void Tarjan::visitChildren(TypePackId tp, int index) std::pair Tarjan::indexify(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); bool fresh = !typeToIndex.contains(ty); int& index = typeToIndex[ty]; @@ -97,7 +98,7 @@ std::pair Tarjan::indexify(TypeId ty) std::pair Tarjan::indexify(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); bool fresh = !packToIndex.contains(tp); int& index = packToIndex[tp]; @@ -115,7 +116,7 @@ std::pair Tarjan::indexify(TypePackId tp) void Tarjan::visitChild(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); edgesTy.push_back(ty); edgesTp.push_back(nullptr); @@ -123,7 +124,7 @@ void Tarjan::visitChild(TypeId ty) void Tarjan::visitChild(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); edgesTy.push_back(nullptr); edgesTp.push_back(tp); @@ -243,7 +244,7 @@ void Tarjan::clear() TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; - ty = follow(ty); + ty = log->follow(ty); clear(); auto [index, fresh] = indexify(ty); @@ -254,7 +255,7 @@ TarjanResult Tarjan::visitRoot(TypeId ty) TarjanResult Tarjan::visitRoot(TypePackId tp) { childCount = 0; - tp = follow(tp); + tp = log->follow(tp); clear(); auto [index, fresh] = indexify(tp); @@ -325,7 +326,7 @@ TarjanResult FindDirty::findDirty(TypePackId tp) std::optional Substitution::substitute(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); newTypes.clear(); newPacks.clear(); @@ -345,7 +346,7 @@ std::optional Substitution::substitute(TypeId ty) std::optional Substitution::substitute(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); newTypes.clear(); newPacks.clear(); @@ -365,11 +366,11 @@ std::optional Substitution::substitute(TypePackId tp) TypeId Substitution::clone(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); TypeId result = ty; - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) { FunctionTypeVar clone = FunctionTypeVar{ftv->level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; clone.generics = ftv->generics; @@ -379,7 +380,7 @@ TypeId Substitution::clone(TypeId ty) clone.argNames = ftv->argNames; result = addType(std::move(clone)); } - else if (const TableTypeVar* ttv = get(ty)) + else if (const TableTypeVar* ttv = log->getMutable(ty)) { LUAU_ASSERT(!ttv->boundTo); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -392,19 +393,19 @@ TypeId Substitution::clone(TypeId ty) clone.tags = ttv->tags; result = addType(std::move(clone)); } - else if (const MetatableTypeVar* mtv = get(ty)) + else if (const MetatableTypeVar* mtv = log->getMutable(ty)) { MetatableTypeVar clone = MetatableTypeVar{mtv->table, mtv->metatable}; clone.syntheticName = mtv->syntheticName; result = addType(std::move(clone)); } - else if (const UnionTypeVar* utv = get(ty)) + else if (const UnionTypeVar* utv = log->getMutable(ty)) { UnionTypeVar clone; clone.options = utv->options; result = addType(std::move(clone)); } - else if (const IntersectionTypeVar* itv = get(ty)) + else if (const IntersectionTypeVar* itv = log->getMutable(ty)) { IntersectionTypeVar clone; clone.parts = itv->parts; @@ -417,15 +418,15 @@ TypeId Substitution::clone(TypeId ty) TypePackId Substitution::clone(TypePackId tp) { - tp = follow(tp); - if (const TypePack* tpp = get(tp)) + tp = log->follow(tp); + if (const TypePack* tpp = log->getMutable(tp)) { TypePack clone; clone.head = tpp->head; clone.tail = tpp->tail; return addTypePack(std::move(clone)); } - else if (const VariadicTypePack* vtp = get(tp)) + else if (const VariadicTypePack* vtp = log->getMutable(tp)) { VariadicTypePack clone; clone.ty = vtp->ty; @@ -437,7 +438,7 @@ TypePackId Substitution::clone(TypePackId tp) void Substitution::foundDirty(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); if (isDirty(ty)) newTypes[ty] = clean(ty); else @@ -446,7 +447,7 @@ void Substitution::foundDirty(TypeId ty) void Substitution::foundDirty(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); if (isDirty(tp)) newPacks[tp] = clean(tp); else @@ -455,7 +456,7 @@ void Substitution::foundDirty(TypePackId tp) TypeId Substitution::replace(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); if (TypeId* prevTy = newTypes.find(ty)) return *prevTy; else @@ -464,7 +465,7 @@ TypeId Substitution::replace(TypeId ty) TypePackId Substitution::replace(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); if (TypePackId* prevTp = newPacks.find(tp)) return *prevTp; else @@ -473,7 +474,7 @@ TypePackId Substitution::replace(TypePackId tp) void Substitution::replaceChildren(TypeId ty) { - ty = follow(ty); + ty = log->follow(ty); if (ignoreChildren(ty)) return; @@ -519,7 +520,7 @@ void Substitution::replaceChildren(TypeId ty) void Substitution::replaceChildren(TypePackId tp) { - tp = follow(tp); + tp = log->follow(tp); if (ignoreChildren(tp)) return; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index a46ac0c35..0968a4c10 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -1,6 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TxnLog.h" +#include "Luau/ToString.h" #include "Luau/TypePack.h" #include @@ -80,6 +81,56 @@ void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) sharedSeen->pop_back(); } +const std::string nullPendingResult = ""; + +std::string toString(PendingType* pending) +{ + if (pending == nullptr) + return nullPendingResult; + + return toString(pending->pending); +} + +std::string dump(PendingType* pending) +{ + if (pending == nullptr) + { + printf("%s\n", nullPendingResult.c_str()); + return nullPendingResult; + } + + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string result = toString(pending->pending, opts); + printf("%s\n", result.c_str()); + return result; +} + +std::string toString(PendingTypePack* pending) +{ + if (pending == nullptr) + return nullPendingResult; + + return toString(pending->pending); +} + +std::string dump(PendingTypePack* pending) +{ + if (pending == nullptr) + { + printf("%s\n", nullPendingResult.c_str()); + return nullPendingResult; + } + + ToStringOptions opts; + opts.exhaustive = true; + opts.functionTypeArguments = true; + std::string result = toString(pending->pending, opts); + printf("%s\n", result.c_str()); + return result; +} + static const TxnLog emptyLog; const TxnLog* TxnLog::empty() @@ -199,8 +250,6 @@ PendingTypePack* TxnLog::queue(TypePackId tp) PendingType* TxnLog::pending(TypeId ty) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) @@ -212,8 +261,6 @@ PendingType* TxnLog::pending(TypeId ty) const PendingTypePack* TxnLog::pending(TypePackId tp) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) @@ -225,8 +272,6 @@ PendingTypePack* TxnLog::pending(TypePackId tp) const PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - PendingType* newTy = queue(ty); newTy->pending = replacement; return newTy; @@ -234,8 +279,6 @@ PendingType* TxnLog::replace(TypeId ty, TypeVar replacement) PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - PendingTypePack* newTp = queue(tp); newTp->pending = replacement; return newTp; @@ -243,7 +286,6 @@ PendingTypePack* TxnLog::replace(TypePackId tp, TypePackVar replacement) PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(ty)); PendingType* newTy = queue(ty); @@ -255,7 +297,6 @@ PendingType* TxnLog::bindTable(TypeId ty, std::optional newBoundTo) PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(ty) || get(ty) || get(ty)); PendingType* newTy = queue(ty); @@ -278,7 +319,6 @@ PendingType* TxnLog::changeLevel(TypeId ty, TypeLevel newLevel) PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(tp)); PendingTypePack* newTp = queue(tp); @@ -292,7 +332,6 @@ PendingTypePack* TxnLog::changeLevel(TypePackId tp, TypeLevel newLevel) PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexer) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(get(ty)); PendingType* newTy = queue(ty); @@ -306,8 +345,6 @@ PendingType* TxnLog::changeIndexer(TypeId ty, std::optional indexe std::optional TxnLog::getLevel(TypeId ty) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - if (FreeTypeVar* ftv = getMutable(ty)) return ftv->level; else if (TableTypeVar* ttv = getMutable(ty); ttv && (ttv->state == TableState::Free || ttv->state == TableState::Generic)) @@ -318,10 +355,8 @@ std::optional TxnLog::getLevel(TypeId ty) const return std::nullopt; } -TypeId TxnLog::follow(TypeId ty) +TypeId TxnLog::follow(TypeId ty) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - return Luau::follow(ty, [this](TypeId ty) { PendingType* state = this->pending(ty); @@ -337,8 +372,6 @@ TypeId TxnLog::follow(TypeId ty) TypePackId TxnLog::follow(TypePackId tp) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - return Luau::follow(tp, [this](TypePackId tp) { PendingTypePack* state = this->pending(tp); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 23fcc2d5b..4d25fe2e3 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,6 +32,7 @@ LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) +LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) @@ -40,13 +41,12 @@ LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauLValueAsKey, false) -LUAU_FASTFLAGVARIABLE(LuauRefiLookupFromIndexExpr, false) LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) -LUAU_FASTFLAGVARIABLE(LuauBidirectionalAsExpr, false) +LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) +LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) namespace Luau { @@ -1117,7 +1117,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco ty = follow(ty); - if (tableSelf && !selfTy->persistent) + if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; const FunctionTypeVar* funTy = get(ty); @@ -1130,7 +1130,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && !selfTy->persistent) + if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) tableSelf->props[indexName->index.value] = { follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; } @@ -1657,7 +1657,7 @@ std::optional TypeChecker::getIndexTypeFromType( RecursionLimiter _rl(&recursionCount, FInt::LuauTypeInferRecursionLimit); // Not needed when we normalize types. - if (FFlag::LuauLValueAsKey && get(follow(t))) + if (get(follow(t))) return t; if (std::optional ty = getIndexTypeFromType(scope, t, name, location, false)) @@ -1802,12 +1802,9 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn { TypeId ty = checkLValue(scope, expr); - if (FFlag::LuauRefiLookupFromIndexExpr) - { - if (std::optional lvalue = tryGetLValue(expr)) - if (std::optional refiTy = resolveLValue(scope, *lvalue)) - return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; - } + if (std::optional lvalue = tryGetLValue(expr)) + if (std::optional refiTy = resolveLValue(scope, *lvalue)) + return {*refiTy, {TruthyPredicate{std::move(*lvalue), expr.location}}}; return {ty}; } @@ -2471,33 +2468,28 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi { if (expr.op == AstExprBinary::And) { - ExprResult lhs = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); - // We can't just report errors here. - // This function can be called from AstStatLocal or from AstStatIf, or even from AstExprBinary (and others). - // For now, ignore the errors returned by the predicate resolver. - // We may need an extra property for each predicate set that indicates it has been resolved. - // Requires a slight modification to the data structure. ScopePtr innerScope = childScope(scope, expr.location); - resolve(lhs.predicates, innerScope, true); + resolve(lhsPredicates, innerScope, true); - ExprResult rhs = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type), - {AndPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy), + {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) { - ExprResult lhs = checkExpr(scope, *expr.left); + auto [lhsTy, lhsPredicates] = checkExpr(scope, *expr.left); ScopePtr innerScope = childScope(scope, expr.location); - resolve(lhs.predicates, innerScope, false); + resolve(lhsPredicates, innerScope, false); - ExprResult rhs = checkExpr(innerScope, *expr.right); + auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - // Because of C++, I'm not sure if lhs.predicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhs.type, rhs.type, lhs.predicates); - return {result, {OrPredicate{std::move(lhs.predicates), std::move(rhs.predicates)}}}; + // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. + TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); + return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) { @@ -2535,27 +2527,15 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprTy TypeId annotationType = resolveType(scope, *expr.annotation); ExprResult result = checkExpr(scope, *expr.expr, annotationType); - if (FFlag::LuauBidirectionalAsExpr) - { - // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. - if (canUnify(annotationType, result.type, expr.location).empty()) - return {annotationType, std::move(result.predicates)}; - - if (canUnify(result.type, annotationType, expr.location).empty()) - return {annotationType, std::move(result.predicates)}; - - reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); - return {errorRecoveryType(annotationType), std::move(result.predicates)}; - } - else - { - ErrorVec errorVec = canUnify(annotationType, result.type, expr.location); - reportErrors(errorVec); - if (!errorVec.empty()) - annotationType = errorRecoveryType(annotationType); + // Note: As an optimization, we try 'number <: number | string' first, as that is the more likely case. + if (canUnify(annotationType, result.type, expr.location).empty()) + return {annotationType, std::move(result.predicates)}; + if (canUnify(result.type, annotationType, expr.location).empty()) return {annotationType, std::move(result.predicates)}; - } + + reportError(expr.location, TypesAreUnrelated{result.type, annotationType}); + return {errorRecoveryType(annotationType), std::move(result.predicates)}; } ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprError& expr) @@ -4295,7 +4275,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s child.tryUnify(subTy, superTy, /*isFunctionCall*/ false); if (!child.errors.empty()) { - TypeId instantiated = instantiate(scope, subTy, state.location); + TypeId instantiated = instantiate(scope, subTy, state.location, &child.log); if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors @@ -4330,7 +4310,7 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s bool Instantiation::isDirty(TypeId ty) { - if (get(ty)) + if (log->getMutable(ty)) return true; else return false; @@ -4343,7 +4323,7 @@ bool Instantiation::isDirty(TypePackId tp) bool Instantiation::ignoreChildren(TypeId ty) { - if (get(ty)) + if (log->getMutable(ty)) return true; else return false; @@ -4351,7 +4331,7 @@ bool Instantiation::ignoreChildren(TypeId ty) TypeId Instantiation::clean(TypeId ty) { - const FunctionTypeVar* ftv = get(ty); + const FunctionTypeVar* ftv = log->getMutable(ty); LUAU_ASSERT(ftv); FunctionTypeVar clone = FunctionTypeVar{level, ftv->argTypes, ftv->retType, ftv->definition, ftv->hasSelf}; @@ -4362,6 +4342,7 @@ TypeId Instantiation::clean(TypeId ty) // Annoyingly, we have to do this even if there are no generics, // to replace any generic tables. + replaceGenerics.log = log; replaceGenerics.level = level; replaceGenerics.currentModule = currentModule; replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); @@ -4383,7 +4364,7 @@ TypePackId Instantiation::clean(TypePackId tp) bool ReplaceGenerics::ignoreChildren(TypeId ty) { - if (const FunctionTypeVar* ftv = get(ty)) + if (const FunctionTypeVar* ftv = log->getMutable(ty)) // We aren't recursing in the case of a generic function which // binds the same generics. This can happen if, for example, there's recursive types. // If T = (a,T)->T then instantiating T should produce T' = (X,T)->T not T' = (X,T')->T'. @@ -4396,9 +4377,9 @@ bool ReplaceGenerics::ignoreChildren(TypeId ty) bool ReplaceGenerics::isDirty(TypeId ty) { - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) return ttv->state == TableState::Generic; - else if (get(ty)) + else if (log->getMutable(ty)) return std::find(generics.begin(), generics.end(), ty) != generics.end(); else return false; @@ -4406,7 +4387,7 @@ bool ReplaceGenerics::isDirty(TypeId ty) bool ReplaceGenerics::isDirty(TypePackId tp) { - if (get(tp)) + if (log->getMutable(tp)) return std::find(genericPacks.begin(), genericPacks.end(), tp) != genericPacks.end(); else return false; @@ -4415,7 +4396,7 @@ bool ReplaceGenerics::isDirty(TypePackId tp) TypeId ReplaceGenerics::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, TableState::Free}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; @@ -4434,9 +4415,9 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) bool Quantification::isDirty(TypeId ty) { - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) return level.subsumes(ttv->level) && ((ttv->state == TableState::Free) || (ttv->state == TableState::Unsealed)); - else if (const FreeTypeVar* ftv = get(ty)) + else if (const FreeTypeVar* ftv = log->getMutable(ty)) return level.subsumes(ftv->level); else return false; @@ -4444,7 +4425,7 @@ bool Quantification::isDirty(TypeId ty) bool Quantification::isDirty(TypePackId tp) { - if (const FreeTypePack* ftv = get(tp)) + if (const FreeTypePack* ftv = log->getMutable(tp)) return level.subsumes(ftv->level); else return false; @@ -4453,7 +4434,7 @@ bool Quantification::isDirty(TypePackId tp) TypeId Quantification::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) { TableState state = (ttv->state == TableState::Unsealed ? TableState::Sealed : TableState::Generic); TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, state}; @@ -4479,9 +4460,9 @@ TypePackId Quantification::clean(TypePackId tp) bool Anyification::isDirty(TypeId ty) { - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) return (ttv->state == TableState::Free || (FFlag::LuauSealExports && ttv->state == TableState::Unsealed)); - else if (get(ty)) + else if (log->getMutable(ty)) return true; else return false; @@ -4489,7 +4470,7 @@ bool Anyification::isDirty(TypeId ty) bool Anyification::isDirty(TypePackId tp) { - if (get(tp)) + if (log->getMutable(tp)) return true; else return false; @@ -4498,7 +4479,7 @@ bool Anyification::isDirty(TypePackId tp) TypeId Anyification::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = get(ty)) + if (const TableTypeVar* ttv = log->getMutable(ty)) { TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, TableState::Sealed}; clone.methodDefinitionLocations = ttv->methodDefinitionLocations; @@ -4535,6 +4516,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location return ty; } + quantification.log = TxnLog::empty(); quantification.level = scope->level; quantification.generics.clear(); quantification.genericPacks.clear(); @@ -4558,8 +4540,11 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location return *qty; } -TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location) +TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { + LUAU_ASSERT(log != nullptr); + + instantiation.log = FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(); instantiation.level = scope->level; instantiation.currentModule = currentModule; std::optional instantiated = instantiation.substitute(ty); @@ -4574,6 +4559,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { + anyification.log = TxnLog::empty(); anyification.anyType = anyType; anyification.anyTypePack = anyTypePack; anyification.currentModule = currentModule; @@ -4589,6 +4575,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { + anyification.log = TxnLog::empty(); anyification.anyType = anyType; anyification.anyTypePack = anyTypePack; anyification.currentModule = currentModule; @@ -4660,7 +4647,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d } }; - if (auto ttv = getTableType(follow(utk->table))) + if (auto ttv = getTableType(FFlag::LuauUnionTagMatchFix ? utk->table : follow(utk->table))) accumulate(ttv->props); else if (auto ctv = get(follow(utk->table))) { @@ -4775,6 +4762,29 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) return getSingletonTypes().errorRecoveryTypePack(guess); } +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) { + return [this, sense](TypeId ty) -> std::optional { + // any/error/free gets a special pass unconditionally because they can't be decided. + if (get(ty) || get(ty) || get(ty)) + return ty; + + // maps boolean primitive to the corresponding singleton equal to sense + if (isPrim(ty, PrimitiveTypeVar::Boolean)) + return singletonType(sense); + + // if we have boolean singleton, eliminate it if the sense doesn't match with that singleton + if (auto boolean = get(get(ty))) + return boolean->value == sense ? std::optional(ty) : std::nullopt; + + // if we have nil, eliminate it if sense is true, otherwise take it + if (isNil(ty)) + return sense ? std::nullopt : std::optional(ty); + + // at this point, anything else is kept if sense is true, or eliminated otherwise + return sense ? std::optional(ty) : std::nullopt; + }; +} + std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predicate) { std::vector types = Luau::filterMap(type, predicate); @@ -4783,6 +4793,11 @@ std::optional TypeChecker::filterMap(TypeId type, TypeIdPredicate predic return std::nullopt; } +std::optional TypeChecker::pickTypesFromSense(TypeId type, bool sense) +{ + return filterMap(type, mkTruthyPredicate(sense)); +} + TypeId TypeChecker::addTV(TypeVar&& tv) { return currentModule->internalTypes.addType(std::move(tv)); @@ -5293,6 +5308,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, for (size_t i = 0; i < tf.typePackParams.size(); ++i) applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; + applyTypeFunction.log = TxnLog::empty(); applyTypeFunction.currentModule = currentModule; applyTypeFunction.level = scope->level; applyTypeFunction.encounteredForwardedType = false; @@ -5507,9 +5523,6 @@ void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LValue& lvalue) { - if (!FFlag::LuauLValueAsKey) - return DEPRECATED_resolveLValue(scope, lvalue); - // We want to be walking the Scope parents. // We'll also want to walk up the LValue path. As we do this, we need to save each LValue because we must walk back. // For example: @@ -5529,7 +5542,7 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV const LValue* currentLValue = &lvalue; while (currentLValue) { - if (auto it = currentScope->refinements.NEW_refinements.find(*currentLValue); it != currentScope->refinements.NEW_refinements.end()) + if (auto it = currentScope->refinements.find(*currentLValue); it != currentScope->refinements.end()) { found = it->second; break; @@ -5576,43 +5589,9 @@ std::optional TypeChecker::resolveLValue(const ScopePtr& scope, const LV return std::nullopt; } -std::optional TypeChecker::DEPRECATED_resolveLValue(const ScopePtr& scope, const LValue& lvalue) -{ - auto [symbol, keys] = getFullName(lvalue); - - ScopePtr currentScope = scope; - while (currentScope) - { - if (auto it = currentScope->refinements.DEPRECATED_refinements.find(toString(lvalue)); it != currentScope->refinements.DEPRECATED_refinements.end()) - return it->second; - - // Should not be using scope->lookup. This is already recursive. - if (auto it = currentScope->bindings.find(symbol); it != currentScope->bindings.end()) - { - std::optional currentTy = it->second.typeId; - - for (std::string key : keys) - { - // TODO: This function probably doesn't need Location at all, or at least should hide the argument. - currentTy = getIndexTypeFromType(scope, *currentTy, key, Location(), false); - if (!currentTy) - break; - } - - return currentTy; - } - - currentScope = currentScope->parent; - } - - return std::nullopt; -} - std::optional TypeChecker::resolveLValue(const RefinementMap& refis, const ScopePtr& scope, const LValue& lvalue) { - if (auto it = refis.DEPRECATED_refinements.find(toString(lvalue)); it != refis.DEPRECATED_refinements.end()) - return it->second; - else if (auto it = refis.NEW_refinements.find(lvalue); it != refis.NEW_refinements.end()) + if (auto it = refis.find(lvalue); it != refis.end()) return it->second; else return resolveLValue(scope, lvalue); @@ -5661,35 +5640,46 @@ void TypeChecker::resolve(const Predicate& predicate, ErrorVec& errVec, Refineme void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, RefinementMap& refis, const ScopePtr& scope, bool sense, bool fromOr) { - auto predicate = [sense](TypeId option) -> std::optional { - if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) - return option; - - return std::nullopt; - }; - - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauAssertStripsFalsyTypes) { std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); if (ty && fromOr) return addRefinement(refis, truthyP.lvalue, *ty); - refineLValue(truthyP.lvalue, refis, scope, predicate); + refineLValue(truthyP.lvalue, refis, scope, mkTruthyPredicate(sense)); } else { - std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); - if (!ty) - return; + auto predicate = [sense](TypeId option) -> std::optional { + if (isUndecidable(option) || isBoolean(option) || isNil(option) != sense) + return option; - // This is a hack. :( - // Without this, the expression 'a or b' might refine 'b' to be falsy. - // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. - if (fromOr) - return addRefinement(refis, truthyP.lvalue, *ty); + return std::nullopt; + }; - if (std::optional result = filterMap(*ty, predicate)) - addRefinement(refis, truthyP.lvalue, *result); + if (FFlag::LuauDiscriminableUnions) + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (ty && fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + refineLValue(truthyP.lvalue, refis, scope, predicate); + } + else + { + std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); + if (!ty) + return; + + // This is a hack. :( + // Without this, the expression 'a or b' might refine 'b' to be falsy. + // I'm not yet sure how else to get this to do the right thing without this hack, so we'll do this for now in the meantime. + if (fromOr) + return addRefinement(refis, truthyP.lvalue, *ty); + + if (std::optional result = filterMap(*ty, predicate)) + addRefinement(refis, truthyP.lvalue, *result); + } } } @@ -5929,7 +5919,9 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa if (!sense || canUnify(eqP.type, option, eqP.location).empty()) return sense ? eqP.type : option; - return std::nullopt; + // local variable works around an odd gcc 9.3 warning: may be used uninitialized + std::optional res = std::nullopt; + return res; } return option; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 5b162b31b..2321eafda 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -27,6 +27,7 @@ LUAU_FASTFLAG(LuauLengthOnCompositeType) LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(LuauUnionTagMatchFix) namespace Luau { @@ -288,10 +289,13 @@ std::optional getMetatable(TypeId type) const TableTypeVar* getTableType(TypeId type) { + if (FFlag::LuauUnionTagMatchFix) + type = follow(type); + if (const TableTypeVar* ttv = get(type)) return ttv; else if (const MetatableTypeVar* mtv = get(type)) - return get(mtv->table); + return get(FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table); else return nullptr; } @@ -308,7 +312,7 @@ const std::string* getName(TypeId type) { if (mtv->syntheticName) return &*mtv->syntheticName; - type = mtv->table; + type = FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table; } if (auto ttv = get(type)) diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index 1f7ef8c25..f037351e5 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -20,7 +20,6 @@ const size_t kPageSize = sysconf(_SC_PAGESIZE); #include LUAU_FASTFLAG(DebugLuauFreezeArena) -LUAU_FASTFLAGVARIABLE(LuauTypedAllocatorZeroStart, false) namespace Luau { diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 17d9bf58f..89e4ae237 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -8,6 +8,7 @@ #include "Luau/TypeUtils.h" #include "Luau/TimeTrace.h" #include "Luau/VisitTypeVar.h" +#include "Luau/ToString.h" #include @@ -22,6 +23,7 @@ LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) +LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) namespace Luau { @@ -225,19 +227,33 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { - type = follow(type); - - if (auto ttv = get(type)) + if (FFlag::LuauUnionTagMatchFix) { - for (auto&& [name, prop] : ttv->props) + if (auto ttv = getTableType(type)) { - if (auto sing = get(follow(prop.type))) - return {{name, sing}}; + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } } } - else if (auto mttv = get(type)) + else { - return getTableMatchTag(mttv->table); + type = follow(type); + + if (auto ttv = get(type)) + { + for (auto&& [name, prop] : ttv->props) + { + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; + } + } + else if (auto mttv = get(type)) + { + return getTableMatchTag(mttv->table); + } } return std::nullopt; @@ -508,293 +524,316 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) { - // A | B <: T if A <: T and B <: T - bool failed = false; - std::optional unificationTooComplex; - std::optional firstFailedOption; + tryUnifyUnionWithType(subTy, uv, superTy); + } + else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + { + tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); + } + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + { + tryUnifyTypeWithIntersection(subTy, superTy, uv); + } + else if (const IntersectionTypeVar* uv = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) + { + tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); + } + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyPrimitives(subTy, superTy); + + else if (FFlag::LuauSingletonTypes && + ((FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) || + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) && + (FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy))) + tryUnifySingletons(subTy, superTy); + + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + tryUnifyFunctions(subTy, superTy, isFunctionCall); + + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + { + tryUnifyTables(subTy, superTy, isIntersection); + + if (cacheEnabled && errors.empty()) + cacheResult(subTy, superTy); + } + + // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); + + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || + (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ false); + + // Unification of nonclasses with classes is almost, but not quite symmetrical. + // The order in which we perform this test is significant in the case that both types are classes. + else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + tryUnifyWithClass(subTy, superTy, /*reversed*/ true); + + else + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superTy, subTy); + else + DEPRECATED_log.popSeen(superTy, subTy); +} + +void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy) +{ + // A | B <: T if A <: T and B <: T + bool failed = false; + std::optional unificationTooComplex; + std::optional firstFailedOption; - size_t count = uv->options.size(); - size_t i = 0; + size_t count = uv->options.size(); + size_t i = 0; - for (TypeId type : uv->options) + for (TypeId type : uv->options) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy); + + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy); + // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' + if (!firstFailedOption && !isNil(type)) + firstFailedOption = {innerState.errors.front()}; - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - { - // 'nil' option is skipped from extended report because we present the type in a special way - 'T?' - if (!firstFailedOption && !isNil(type)) - firstFailedOption = {innerState.errors.front()}; + failed = true; + } - failed = true; + if (FFlag::LuauUseCommittingTxnLog) + { + if (i == count - 1) + { + log.concat(std::move(innerState.log)); } - - if (FFlag::LuauUseCommittingTxnLog) + } + else + { + if (i != count - 1) { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } + innerState.DEPRECATED_log.rollback(); } else { - if (i != count - 1) - { - innerState.DEPRECATED_log.rollback(); - } - else - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } - - ++i; } - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (failed) - { - if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); - else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - } + ++i; } - else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (failed) { - // T <: A | B if T <: A or T <: B - bool found = false; - std::optional unificationTooComplex; + if (firstFailedOption) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all union options are compatible.", *firstFailedOption}}); + else + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + } +} - size_t failedOptionCount = 0; - std::optional failedOption; +void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTypeVar* uv, bool cacheEnabled, bool isFunctionCall) +{ + // T <: A | B if T <: A or T <: B + bool found = false; + std::optional unificationTooComplex; - bool foundHeuristic = false; - size_t startIndex = 0; + size_t failedOptionCount = 0; + std::optional failedOption; - if (const std::string* subName = getName(subTy)) - { - for (size_t i = 0; i < uv->options.size(); ++i) - { - const std::string* optionName = getName(uv->options[i]); - if (optionName && *optionName == *subName) - { - foundHeuristic = true; - startIndex = i; - break; - } - } - } + bool foundHeuristic = false; + size_t startIndex = 0; - if (auto subMatchTag = getTableMatchTag(subTy)) + if (const std::string* subName = getName(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) + const std::string* optionName = getName(uv->options[i]); + if (optionName && *optionName == *subName) { - auto optionMatchTag = getTableMatchTag(uv->options[i]); - if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) - { - foundHeuristic = true; - startIndex = i; - break; - } + foundHeuristic = true; + startIndex = i; + break; } } + } - if (!foundHeuristic && cacheEnabled) + if (auto subMatchTag = getTableMatchTag(subTy)) + { + for (size_t i = 0; i < uv->options.size(); ++i) { - for (size_t i = 0; i < uv->options.size(); ++i) + auto optionMatchTag = getTableMatchTag(uv->options[i]); + if (optionMatchTag && optionMatchTag->first == subMatchTag->first && *optionMatchTag->second == *subMatchTag->second) { - TypeId type = uv->options[i]; - - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) - { - startIndex = i; - break; - } + foundHeuristic = true; + startIndex = i; + break; } } + } + + if (!foundHeuristic && cacheEnabled) + { + auto& cache = sharedState.cachedUnify; for (size_t i = 0; i < uv->options.size(); ++i) { - TypeId type = uv->options[(i + startIndex) % uv->options.size()]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTy, type, isFunctionCall); + TypeId type = uv->options[i]; - if (innerState.errors.empty()) + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) { - found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - + startIndex = i; break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - { - unificationTooComplex = e; - } - else if (!isNil(type)) - { - failedOptionCount++; + } + } - if (!failedOption) - failedOption = {innerState.errors.front()}; - } + for (size_t i = 0; i < uv->options.size(); ++i) + { + TypeId type = uv->options[(i + startIndex) % uv->options.size()]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(subTy, type, isFunctionCall); - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + { + found = true; + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - if (unificationTooComplex) + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) { - reportError(*unificationTooComplex); + unificationTooComplex = e; } - else if (!found) + else if (!isNil(type)) { - if ((failedOptionCount == 1 || foundHeuristic) && failedOption) - reportError( - TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); - else - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + failedOptionCount++; + + if (!failedOption) + failedOption = {innerState.errors.front()}; } + + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } - else if (const IntersectionTypeVar* uv = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + + if (unificationTooComplex) + { + reportError(*unificationTooComplex); + } + else if (!found) { - std::optional unificationTooComplex; - std::optional firstFailedOption; + if ((failedOptionCount == 1 || foundHeuristic) && failedOption) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "None of the union options are compatible. For example:", *failedOption}}); + else + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the union options are compatible"}}); + } +} - // T <: A & B if A <: T and B <: T - for (TypeId type : uv->parts) - { - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); +void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const IntersectionTypeVar* uv) +{ + std::optional unificationTooComplex; + std::optional firstFailedOption; - if (auto e = hasUnificationTooComplex(innerState.errors)) - unificationTooComplex = e; - else if (!innerState.errors.empty()) - { - if (!firstFailedOption) - firstFailedOption = {innerState.errors.front()}; - } + // T <: A & B if A <: T and B <: T + for (TypeId type : uv->parts) + { + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(subTy, type, /*isFunctionCall*/ false, /*isIntersection*/ true); - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + if (auto e = hasUnificationTooComplex(innerState.errors)) + unificationTooComplex = e; + else if (!innerState.errors.empty()) + { + if (!firstFailedOption) + firstFailedOption = {innerState.errors.front()}; } - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (firstFailedOption) - reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); } - else if (const IntersectionTypeVar* uv = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) - { - // A & B <: T if T <: A or T <: B - bool found = false; - std::optional unificationTooComplex; - size_t startIndex = 0; + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (firstFailedOption) + reportError(TypeError{location, TypeMismatch{superTy, subTy, "Not all intersection parts are compatible.", *firstFailedOption}}); +} + +void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeVar* uv, TypeId superTy, bool cacheEnabled, bool isFunctionCall) +{ + // A & B <: T if T <: A or T <: B + bool found = false; + std::optional unificationTooComplex; - if (cacheEnabled) - { - for (size_t i = 0; i < uv->parts.size(); ++i) - { - TypeId type = uv->parts[i]; + size_t startIndex = 0; - if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) - { - startIndex = i; - break; - } - } - } + if (cacheEnabled) + { + auto& cache = sharedState.cachedUnify; for (size_t i = 0; i < uv->parts.size(); ++i) { - TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; - Unifier innerState = makeChildUnifier(); - innerState.tryUnify_(type, superTy, isFunctionCall); + TypeId type = uv->parts[i]; - if (innerState.errors.empty()) + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) { - found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + startIndex = i; break; } - else if (auto e = hasUnificationTooComplex(innerState.errors)) - { - unificationTooComplex = e; - } - - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); } + } - if (unificationTooComplex) - reportError(*unificationTooComplex); - else if (!found) + for (size_t i = 0; i < uv->parts.size(); ++i) + { + TypeId type = uv->parts[(i + startIndex) % uv->parts.size()]; + Unifier innerState = makeChildUnifier(); + innerState.tryUnify_(type, superTy, isFunctionCall); + + if (innerState.errors.empty()) { - reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); + found = true; + if (FFlag::LuauUseCommittingTxnLog) + log.concat(std::move(innerState.log)); + else + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + break; + } + else if (auto e = hasUnificationTooComplex(innerState.errors)) + { + unificationTooComplex = e; } - } - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) - tryUnifyPrimitives(subTy, superTy); - - else if (FFlag::LuauSingletonTypes && - ((FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) || - (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) && - (FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy))) - tryUnifySingletons(subTy, superTy); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) - tryUnifyFunctions(subTy, superTy, isFunctionCall); + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); + } - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + if (unificationTooComplex) + reportError(*unificationTooComplex); + else if (!found) { - tryUnifyTables(subTy, superTy, isIntersection); - - if (cacheEnabled && errors.empty()) - cacheResult(subTy, superTy); + reportError(TypeError{location, TypeMismatch{superTy, subTy, "none of the intersection parts are compatible"}}); } - - // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy))) - tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(subTy))) - tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy))) - tryUnifyWithClass(subTy, superTy, /*reversed*/ false); - - // Unification of nonclasses with classes is almost, but not quite symmetrical. - // The order in which we perform this test is significant in the case that both types are classes. - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(subTy))) - tryUnifyWithClass(subTy, superTy, /*reversed*/ true); - - else - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - - if (FFlag::LuauUseCommittingTxnLog) - log.popSeen(superTy, subTy); - else - DEPRECATED_log.popSeen(superTy, subTy); } void Unifier::cacheResult(TypeId subTy, TypeId superTy) @@ -1119,8 +1158,8 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal auto [superTypes, superTail] = logAwareFlatten(superTp, log); auto [subTypes, subTail] = logAwareFlatten(subTp, log); - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); + bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && + (subTail && log.getMutable(*subTail)); auto superIter = WeirdIter(superTp, log); auto subIter = WeirdIter(subTp, log); @@ -1667,6 +1706,13 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TableTypeVar* superTable = getMutable(superTy); TableTypeVar* subTable = getMutable(subTy); + + if (FFlag::LuauUseCommittingTxnLog) + { + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1679,7 +1725,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) for (const auto& [propName, superProp] : superTable->props) { auto subIter = subTable->props.find(propName); - if (subIter == subTable->props.end() && !isOptional(superProp.type) && !get(follow(superProp.type))) + + bool isAny = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(superProp.type)) : get(follow(superProp.type)); + + if (subIter == subTable->props.end() && !isOptional(superProp.type) && !isAny) missingProperties.push_back(propName); } @@ -1697,7 +1747,10 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) for (const auto& [propName, subProp] : subTable->props) { auto superIter = superTable->props.find(propName); - if (superIter == superTable->props.end() && !isOptional(subProp.type) && !get(follow(subProp.type))) + + bool isAny = + FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(subProp.type)) : get(follow(subProp.type)); + if (superIter == superTable->props.end() && !isOptional(subProp.type) && !isAny) extraProperties.push_back(propName); } @@ -1775,6 +1828,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TableTypeVar* ttv = getMutable(pendingSub); LUAU_ASSERT(ttv); ttv->props[name] = prop; + subTable = ttv; } else { @@ -1831,6 +1885,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) PendingType* pendingSuper = log.queue(superTy); TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); pendingSuperTtv->props[name] = clone; + superTable = pendingSuperTtv; } else { @@ -1853,6 +1908,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) PendingType* pendingSuper = log.queue(superTy); TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); pendingSuperTtv->props[name] = prop; + superTable = pendingSuperTtv; } else { @@ -1967,7 +2023,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else { - DEPRECATED_log(subTy); + DEPRECATED_log(subTable); subTable->boundTo = superTy; } } @@ -2408,8 +2464,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) if (auto e = hasUnificationTooComplex(innerState.errors)) reportError(*e); else if (!innerState.errors.empty()) - reportError( - TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); + reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); if (FFlag::LuauUseCommittingTxnLog) log.concat(std::move(innerState.log)); diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 3c607d24c..f559e2e07 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,7 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauFixAmbiguousErrorRecoveryInAssign, false) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) @@ -957,7 +956,7 @@ AstStat* Parser::parseAssignment(AstExpr* initial) { nextLexeme(); - AstExpr* expr = parsePrimaryExpr(/* asStatement= */ FFlag::LuauFixAmbiguousErrorRecoveryInAssign); + AstExpr* expr = parsePrimaryExpr(/* asStatement= */ true); if (!isExprLValue(expr)) expr = reportExprError(expr->location, copy({expr}), "Assigned expression must be a variable or a field"); diff --git a/CLI/Coverage.cpp b/CLI/Coverage.cpp index 254df3f03..a509ab89a 100644 --- a/CLI/Coverage.cpp +++ b/CLI/Coverage.cpp @@ -68,7 +68,7 @@ void coverageDump(const char* path) fprintf(f, "TN:\n"); - for (int fref: gCoverage.functions) + for (int fref : gCoverage.functions) { lua_getref(L, fref); diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index c68070227..fe005aece 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -77,7 +77,7 @@ std::optional readFile(const std::string& name) std::optional readStdin() { std::string result; - char buffer[4096] = { }; + char buffer[4096] = {}; while (fgets(buffer, sizeof(buffer), stdin) != nullptr) result.append(buffer); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index ab0f0ed08..5af6b508f 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -10,7 +10,7 @@ #include "Profiler.h" #include "Coverage.h" -#include "linenoise.hpp" +#include "isocline.h" #include @@ -240,9 +240,10 @@ std::string runCode(lua_State* L, const std::string& source) return std::string(); } -static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, std::vector& completions) +static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) { - std::string_view lookup = editBuffer + start; + auto* L = reinterpret_cast(ic_completion_arg(cenv)); + std::string_view lookup = editBuffer; char lastSep = 0; for (;;) @@ -268,13 +269,14 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) { - std::string completion(editBuffer + std::string(key.substr(prefix.size()))); + std::string completedComponent(key.substr(prefix.size())); + std::string completion(editBuffer + completedComponent); if (valueType == LUA_TFUNCTION) { // Add an opening paren for function calls by default. completion += "("; } - completions.push_back(completion); + ic_add_completion_ex(cenv, completion.data(), key.data(), nullptr); } } lua_pop(L, 1); @@ -310,19 +312,23 @@ static void completeIndexer(lua_State* L, const char* editBuffer, size_t start, lua_pop(L, 1); } -static void completeRepl(lua_State* L, const char* editBuffer, std::vector& completions) +static bool isMethodOrFunctionChar(const char* s, long len) { - size_t start = strlen(editBuffer); - while (start > 0 && (isalnum(editBuffer[start - 1]) || editBuffer[start - 1] == '.' || editBuffer[start - 1] == ':' || editBuffer[start - 1] == '_')) - start--; + char c = *s; + return len == 1 && (isalnum(c) || c == '.' || c == ':' || c == '_'); +} + +static void completeRepl(ic_completion_env_t* cenv, const char* editBuffer) +{ + auto* L = reinterpret_cast(ic_completion_arg(cenv)); // look the value up in current global table first lua_pushvalue(L, LUA_GLOBALSINDEX); - completeIndexer(L, editBuffer, start, completions); + ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); // and in actual global table after that lua_getglobal(L, "_G"); - completeIndexer(L, editBuffer, start, completions); + ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); } struct LinenoiseScopedHistory @@ -341,13 +347,11 @@ struct LinenoiseScopedHistory } if (!historyFilepath.empty()) - linenoise::LoadHistory(historyFilepath.c_str()); + ic_set_history(historyFilepath.c_str(), -1 /* default entries (= 200) */); } ~LinenoiseScopedHistory() { - if (!historyFilepath.empty()) - linenoise::SaveHistory(historyFilepath.c_str()); } std::string historyFilepath; @@ -355,28 +359,32 @@ struct LinenoiseScopedHistory static void runReplImpl(lua_State* L) { - linenoise::SetCompletionCallback([L](const char* editBuffer, std::vector& completions) { - completeRepl(L, editBuffer, completions); - }); + ic_set_default_completer(completeRepl, L); + + // Make brace matching easier to see + ic_style_def("ic-bracematch", "teal"); + + // Prevent auto insertion of braces + ic_enable_brace_insertion(false); std::string buffer; LinenoiseScopedHistory scopedHistory; for (;;) { - bool quit = false; - std::string line = linenoise::Readline(buffer.empty() ? "> " : ">> ", quit); - if (quit) + const char* line = ic_readline(buffer.empty() ? "" : ">"); + if (!line) break; if (buffer.empty() && runCode(L, std::string("return ") + line) == std::string()) { - linenoise::AddHistory(line.c_str()); + ic_history_add(line); continue; } + if (!buffer.empty()) + buffer += "\n"; buffer += line; - buffer += " "; // linenoise doesn't work very well with multiline history entries std::string error = runCode(L, buffer); @@ -390,8 +398,9 @@ static void runReplImpl(lua_State* L) fprintf(stdout, "%s\n", error.c_str()); } - linenoise::AddHistory(buffer.c_str()); + ic_history_add(buffer.c_str()); buffer.clear(); + free((void*)line); } } diff --git a/CLI/ReplEntry.cpp b/CLI/ReplEntry.cpp index b31317128..75995e6a8 100644 --- a/CLI/ReplEntry.cpp +++ b/CLI/ReplEntry.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details - #include "Repl.h" @@ -7,4 +6,4 @@ int main(int argc, char** argv) { return replMain(argc, argv); -} \ No newline at end of file +} diff --git a/CMakeLists.txt b/CMakeLists.txt index b9f7a9e11..881d3c3f9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,7 +5,7 @@ if(EXT_PLATFORM_STRING) endif() cmake_minimum_required(VERSION 3.0) -project(Luau LANGUAGES CXX) +project(Luau LANGUAGES CXX C) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) @@ -16,6 +16,7 @@ add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) add_library(Luau.VM STATIC) +add_library(isocline STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) @@ -52,6 +53,8 @@ target_link_libraries(Luau.Analysis PUBLIC Luau.Ast) target_compile_features(Luau.VM PRIVATE cxx_std_11) target_include_directories(Luau.VM PUBLIC VM/include) +target_include_directories(isocline PUBLIC extern/isocline/include) + set(LUAU_OPTIONS) if(MSVC) @@ -75,9 +78,16 @@ if(LUAU_BUILD_WEB) list(APPEND LUAU_OPTIONS -fexceptions) endif() +set(ISOCLINE_OPTIONS) + +if (NOT MSVC) + list(APPEND ISOCLINE_OPTIONS -Wno-unused-function) +endif() + target_compile_options(Luau.Ast PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analysis PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.VM PRIVATE ${LUAU_OPTIONS}) +target_compile_options(isocline PRIVATE ${LUAU_OPTIONS} ${ISOCLINE_OPTIONS}) if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) # disable partial redundancy elimination which regresses interpreter codegen substantially in VS2022: @@ -89,8 +99,9 @@ if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) - target_include_directories(Luau.Repl.CLI PRIVATE extern) - target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM) + target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) + + target_link_libraries(Luau.Repl.CLI PRIVATE Luau.Compiler Luau.VM isocline) if(UNIX) find_library(LIBPTHREAD pthread) @@ -113,7 +124,7 @@ if(LUAU_BUILD_TESTS) target_compile_options(Luau.CLI.Test PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.CLI.Test PRIVATE extern CLI) - target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM) + target_link_libraries(Luau.CLI.Test PRIVATE Luau.Compiler Luau.VM isocline) if(UNIX) find_library(LIBPTHREAD pthread) if (LIBPTHREAD) diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index a907271c9..26360c495 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,7 +4,7 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" -LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin, false) +LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin2, false) namespace Luau { @@ -64,7 +64,7 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; - if (FFlag::LuauCompileSelectBuiltin && builtin.isGlobal("select")) + if (FFlag::LuauCompileSelectBuiltin2 && builtin.isGlobal("select")) return LBF_SELECT_VARARG; if (builtin.object == "math") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 7da852447..e4253adc3 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -16,7 +16,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false) -LUAU_FASTFLAG(LuauCompileSelectBuiltin) +LUAU_FASTFLAG(LuauCompileSelectBuiltin2) namespace Luau { @@ -266,7 +266,7 @@ struct Compiler void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); LUAU_ASSERT(targetCount == 1); LUAU_ASSERT(!expr->self); LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is()); @@ -291,6 +291,9 @@ struct Compiler // we can't use TempTop variant here because we need to make sure the arguments we already computed aren't overwritten compileExprTemp(expr->func, regs); + if (argreg != regs + 1) + bytecode.emitABC(LOP_MOVE, uint8_t(regs + 1), argreg, 0); + bytecode.emitABC(LOP_GETVARARGS, uint8_t(regs + 2), 0, 0); size_t callLabel = bytecode.emitLabel(); @@ -405,7 +408,7 @@ struct Compiler if (bfid == LBF_SELECT_VARARG) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin); + LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) diff --git a/Compiler/src/TableShape.cpp b/Compiler/src/TableShape.cpp index 9dc2f0a46..5a866e878 100644 --- a/Compiler/src/TableShape.cpp +++ b/Compiler/src/TableShape.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "TableShape.h" -LUAU_FASTFLAGVARIABLE(LuauPredictTableSizeLoop, false) - namespace Luau { namespace Compile @@ -87,9 +85,6 @@ struct ShapeVisitor : AstVisitor } else if (AstExprLocal* iter = index->as()) { - if (!FFlag::LuauPredictTableSizeLoop) - return; - if (const unsigned int* bound = loops.find(iter->local)) { TableShape& shape = shapes[*table]; @@ -143,9 +138,6 @@ struct ShapeVisitor : AstVisitor bool visit(AstStatFor* node) override { - if (!FFlag::LuauPredictTableSizeLoop) - return true; - AstExprConstantNumber* from = node->from->as(); AstExprConstantNumber* to = node->to->as(); diff --git a/Makefile b/Makefile index 638c4c635..80eff0182 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,10 @@ VM_SOURCES=$(wildcard VM/src/*.cpp) VM_OBJECTS=$(VM_SOURCES:%=$(BUILD)/%.o) VM_TARGET=$(BUILD)/libluauvm.a +ISOCLINE_SOURCES=extern/isocline/src/isocline.c +ISOCLINE_OBJECTS=$(ISOCLINE_SOURCES:%=$(BUILD)/%.o) +ISOCLINE_TARGET=$(BUILD)/libisocline.a + TESTS_SOURCES=$(wildcard tests/*.cpp) CLI/FileUtils.cpp CLI/Profiler.cpp CLI/Coverage.cpp CLI/Repl.cpp TESTS_OBJECTS=$(TESTS_SOURCES:%=$(BUILD)/%.o) TESTS_TARGET=$(BUILD)/luau-tests @@ -43,7 +47,7 @@ ifneq ($(flags),) TESTS_ARGS+=--fflags=$(flags) endif -OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) +OBJECTS=$(AST_OBJECTS) $(COMPILER_OBJECTS) $(ANALYSIS_OBJECTS) $(VM_OBJECTS) $(ISOCLINE_OBJECTS) $(TESTS_OBJECTS) $(CLI_OBJECTS) $(FUZZ_OBJECTS) # common flags CXXFLAGS=-g -Wall @@ -90,8 +94,9 @@ $(AST_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include $(COMPILER_OBJECTS): CXXFLAGS+=-std=c++17 -ICompiler/include -IAst/include $(ANALYSIS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include $(VM_OBJECTS): CXXFLAGS+=-std=c++11 -IVM/include +$(ISOCLINE_OBJECTS): CXXFLAGS+=-Wno-unused-function -Iextern/isocline/include $(TESTS_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include -ICLI -Iextern -$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern +$(REPL_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IVM/include -Iextern -Iextern/isocline/include $(ANALYZE_CLI_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -IAnalysis/include -Iextern $(FUZZ_OBJECTS): CXXFLAGS+=-std=c++17 -IAst/include -ICompiler/include -IAnalysis/include -IVM/include @@ -116,8 +121,8 @@ coverage: $(TESTS_TARGET) $(TESTS_TARGET) llvm-profdata merge default.profraw default-flags.profraw -o default.profdata rm default.profraw default-flags.profraw - llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests - llvm-cov report -ignore-filename-regex=\(tests\|extern\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests + llvm-cov show -format=html -show-instantiations=false -show-line-counts=true -show-region-summary=false -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -output-dir=coverage --instr-profile default.profdata build/coverage/luau-tests + llvm-cov report -ignore-filename-regex=\(tests\|extern\|CLI\)/.* -show-region-summary=false --instr-profile default.profdata build/coverage/luau-tests llvm-cov export -format lcov --instr-profile default.profdata build/coverage/luau-tests >coverage.info format: @@ -135,8 +140,8 @@ luau-analyze: $(ANALYZE_CLI_TARGET) ln -fs $^ $@ # executable targets -$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) -$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) +$(TESTS_TARGET): $(TESTS_OBJECTS) $(ANALYSIS_TARGET) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) +$(REPL_CLI_TARGET): $(REPL_CLI_OBJECTS) $(COMPILER_TARGET) $(AST_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET) $(ANALYZE_CLI_TARGET): $(ANALYZE_CLI_OBJECTS) $(ANALYSIS_TARGET) $(AST_TARGET) $(TESTS_TARGET) $(REPL_CLI_TARGET) $(ANALYZE_CLI_TARGET): @@ -154,8 +159,9 @@ $(AST_TARGET): $(AST_OBJECTS) $(COMPILER_TARGET): $(COMPILER_OBJECTS) $(ANALYSIS_TARGET): $(ANALYSIS_OBJECTS) $(VM_TARGET): $(VM_OBJECTS) +$(ISOCLINE_TARGET): $(ISOCLINE_OBJECTS) -$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(VM_TARGET): +$(AST_TARGET) $(COMPILER_TARGET) $(ANALYSIS_TARGET) $(VM_TARGET) $(ISOCLINE_TARGET): ar rcs $@ $^ # object file targets @@ -163,6 +169,10 @@ $(BUILD)/%.cpp.o: %.cpp @mkdir -p $(dir $@) $(CXX) $< $(CXXFLAGS) -c -MMD -MP -o $@ +$(BUILD)/%.c.o: %.c + @mkdir -p $(dir $@) + $(CXX) -x c $< $(CXXFLAGS) -c -MMD -MP -o $@ + # protobuf fuzzer setup fuzz/luau.pb.cpp: fuzz/luau.proto build/libprotobuf-mutator cd fuzz && ../build/libprotobuf-mutator/external.protobuf/bin/protoc luau.proto --cpp_out=. diff --git a/Sources.cmake b/Sources.cmake index 22e7af223..b36b6db56 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -167,6 +167,11 @@ target_sources(Luau.VM PRIVATE VM/src/lvm.h ) +target_sources(isocline PRIVATE + extern/isocline/include/isocline.h + extern/isocline/src/isocline.c +) + if(TARGET Luau.Repl.CLI) # Luau.Repl.CLI Sources target_sources(Luau.Repl.CLI PRIVATE diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index 7e0832e7c..c5bf1c184 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -83,7 +83,7 @@ #endif #ifndef LUAI_GCSTEPSIZE -#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ #endif /* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ @@ -153,6 +153,6 @@ long l; \ } -#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ +#define LUA_VECTOR_SIZE 3 /* must be 3 or 4 */ #define LUA_EXTRA_SIZE LUA_VECTOR_SIZE - 2 diff --git a/VM/src/lcorolib.cpp b/VM/src/lcorolib.cpp index 192228613..7592a14cd 100644 --- a/VM/src/lcorolib.cpp +++ b/VM/src/lcorolib.cpp @@ -250,7 +250,7 @@ static int coclose(lua_State* L) { lua_pushboolean(L, false); if (lua_gettop(co)) - lua_xmove(co, L, 1); /* move error message */ + lua_xmove(co, L, 1); /* move error message */ lua_resetthread(co); return 2; } diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index 2b5382bba..e9930f7ab 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -12,7 +12,6 @@ #include #include -LUAU_FASTFLAG(LuauBytecodeV2Read) LUAU_FASTFLAG(LuauBytecodeV2Force) static const char* getfuncname(Closure* f); @@ -96,7 +95,7 @@ static int getlinedefined(Proto* p) { if (FFlag::LuauBytecodeV2Force) return p->linedefined; - else if (FFlag::LuauBytecodeV2Read && p->linedefined >= 0) + else if (p->linedefined >= 0) return p->linedefined; else return luaG_getline(p, 0); diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 6088f71c4..582d46277 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -90,7 +90,7 @@ UpVal* luaF_findupval(lua_State* L, StkId level) uv->tt = LUA_TUPVAL; uv->marked = luaC_white(g); uv->memcat = L->activememcat; - uv->v = level; /* current value lives in the stack */ + uv->v = level; /* current value lives in the stack */ // chain the upvalue in the threads open upvalue list at the proper position UpVal* next = *pp; @@ -138,8 +138,8 @@ void luaF_unlinkupval(UpVal* uv) void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) { - if (uv->v != &uv->u.value) /* is it open? */ - luaF_unlinkupval(uv); /* remove from open list */ + if (uv->v != &uv->u.value) /* is it open? */ + luaF_unlinkupval(uv); /* remove from open list */ luaM_freegco(L, uv, sizeof(UpVal), uv->memcat, page); /* free upvalue */ } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 82ac00092..835572fa7 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -759,6 +759,8 @@ static int sweepgcopage(lua_State* L, lua_Page* page) // when true is returned it means that the element was deleted if (sweepgco(L, page, gco)) { + LUAU_ASSERT(busyBlocks > 0); + // if the last block was removed, page would be removed as well if (--busyBlocks == 0) return int(pos - start) / blockSize + 1; diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index e1dbce504..de85cf595 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -57,8 +57,7 @@ const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms -// TODO (FFlagLuauGcPagedSweep): when 'next' is removed, 'kBlockHeader' can be used unconditionally -const size_t kGCOHeader = sizeof(GCheader) > kBlockHeader ? sizeof(GCheader) : kBlockHeader; +const size_t kGCOLinkOffset = (sizeof(GCheader) + sizeof(void*) - 1) & ~(sizeof(void*) - 1); // GCO pages contain freelist links after the GC header struct SizeClassConfig { @@ -101,12 +100,12 @@ struct SizeClassConfig const SizeClassConfig kSizeClassConfig; -// size class for a block of size sz +// size class for a block of size sz; returns -1 for size=0 because empty allocations take no space #define sizeclass(sz) (size_t((sz)-1) < kMaxSmallSize ? kSizeClassConfig.classForSize[sz] : -1) // metadata for a block is stored in the first pointer of the block #define metadata(block) (*(void**)(block)) -#define freegcolink(block) (*(void**)((char*)block + kGCOHeader)) +#define freegcolink(block) (*(void**)((char*)block + kGCOLinkOffset)) /* ** About the realloc function: @@ -157,7 +156,7 @@ l_noret luaM_toobig(lua_State* L) luaG_runerror(L, "memory allocation error: block too big"); } -static lua_Page* luaM_newpage(lua_State* L, uint8_t sizeClass) +static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) { LUAU_ASSERT(!FFlag::LuauGcPagedSweep); @@ -253,7 +252,7 @@ static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** g return page; } -static void luaM_freepage(lua_State* L, lua_Page* page, uint8_t sizeClass) +static void freepageold(lua_State* L, lua_Page* page, uint8_t sizeClass) { LUAU_ASSERT(!FFlag::LuauGcPagedSweep); @@ -310,7 +309,7 @@ static void freeclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopa freepage(L, gcopageset, page); } -static void* luaM_newblock(lua_State* L, int sizeClass) +static void* newblock(lua_State* L, int sizeClass) { global_State* g = L->global; lua_Page* page = g->freepages[sizeClass]; @@ -321,7 +320,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) if (FFlag::LuauGcPagedSweep) page = newclasspage(L, g->freepages, NULL, sizeClass, true); else - page = luaM_newpage(L, sizeClass); + page = newpageold(L, sizeClass); } LUAU_ASSERT(!page->prev); @@ -363,7 +362,7 @@ static void* luaM_newblock(lua_State* L, int sizeClass) return (char*)block + kBlockHeader; } -static void* luaM_newgcoblock(lua_State* L, int sizeClass) +static void* newgcoblock(lua_State* L, int sizeClass) { LUAU_ASSERT(FFlag::LuauGcPagedSweep); @@ -390,11 +389,10 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) } else { - // when separate block metadata is not used, free list link is stored inside the block data itself - block = (char*)page->freeList - kGCOHeader; - - ASAN_UNPOISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); + block = page->freeList; + ASAN_UNPOISON_MEMORY_REGION((char*)block + sizeof(GCheader), page->blockSize - sizeof(GCheader)); + // when separate block metadata is not used, free list link is stored inside the block data itself page->freeList = freegcolink(block); page->busyBlocks++; } @@ -412,7 +410,7 @@ static void* luaM_newgcoblock(lua_State* L, int sizeClass) return (char*)block; } -static void luaM_freeblock(lua_State* L, int sizeClass, void* block) +static void freeblock(lua_State* L, int sizeClass, void* block) { global_State* g = L->global; @@ -450,11 +448,11 @@ static void luaM_freeblock(lua_State* L, int sizeClass, void* block) if (FFlag::LuauGcPagedSweep) freeclasspage(L, g->freepages, NULL, page, sizeClass); else - luaM_freepage(L, page, sizeClass); + freepageold(L, page, sizeClass); } } -static void luaM_freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) +static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) { LUAU_ASSERT(FFlag::LuauGcPagedSweep); @@ -474,9 +472,9 @@ static void luaM_freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page // when separate block metadata is not used, free list link is stored inside the block data itself freegcolink(block) = page->freeList; - page->freeList = (char*)block + kGCOHeader; + page->freeList = block; - ASAN_POISON_MEMORY_REGION((char*)block + kGCOHeader, page->blockSize - kGCOHeader); + ASAN_POISON_MEMORY_REGION((char*)block + sizeof(GCheader), page->blockSize - sizeof(GCheader)); page->busyBlocks--; @@ -491,7 +489,7 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) int nclass = sizeclass(nsize); - void* block = nclass >= 0 ? luaM_newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + void* block = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); if (block == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -506,6 +504,9 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) if (!FFlag::LuauGcPagedSweep) return (GCObject*)luaM_new_(L, nsize, memcat); + // we need to accommodate space for link for free blocks (freegcolink) + LUAU_ASSERT(nsize >= kGCOLinkOffset + sizeof(void*)); + global_State* g = L->global; int nclass = sizeclass(nsize); @@ -514,9 +515,7 @@ GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) if (nclass >= 0) { - LUAU_ASSERT(nsize > 8); - - block = luaM_newgcoblock(L, nclass); + block = newgcoblock(L, nclass); } else { @@ -546,7 +545,7 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) int oclass = sizeclass(osize); if (oclass >= 0) - luaM_freeblock(L, oclass, block); + freeblock(L, oclass, block); else (*g->frealloc)(L, g->ud, block, osize, 0); @@ -571,7 +570,7 @@ void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, { block->gch.tt = LUA_TNIL; - luaM_freegcoblock(L, oclass, block, page); + freegcoblock(L, oclass, block, page); } else { @@ -596,7 +595,7 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 // if either block needs to be allocated using a block allocator, we can't use realloc directly if (nclass >= 0 || oclass >= 0) { - result = nclass >= 0 ? luaM_newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + result = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); if (result == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -604,7 +603,7 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 memcpy(result, block, osize < nsize ? osize : nsize); if (oclass >= 0) - luaM_freeblock(L, oclass, block); + freeblock(L, oclass, block); else (*g->frealloc)(L, g->ud, block, osize, 0); } @@ -659,6 +658,8 @@ void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context // when true is returned it means that the element was deleted if (visitor(context, page, gco)) { + LUAU_ASSERT(busyBlocks > 0); + // if the last block was removed, page would be removed as well if (--busyBlocks == 0) break; diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index cb22cc23a..9bbc43dec 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -57,7 +57,7 @@ void luaS_resize(lua_State* L, int newsize) { TString* p = tb->hash[i]; while (p) - { /* for each node in the list */ + { /* for each node in the list */ // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required TString* next = (TString*)p->next; /* save next */ unsigned int h = p->hash; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index e58ff2a8e..cba3670ad 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -676,14 +676,9 @@ static void luau_execute(lua_State* L) VM_PROTECT_PC(); // set may fail TValue* res = luaH_setstr(L, h, tsvalue(kv)); - - if (res != luaO_nilobject) - { - int cachedslot = gval2slot(h, res); - // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ - VM_PATCH_C(pc - 2, cachedslot); - } - + int cachedslot = gval2slot(h, res); + // save cachedslot to accelerate future lookups; patches currently executing instruction since pc-2 rolls back two pc++ + VM_PATCH_C(pc - 2, cachedslot); setobj(L, res, ra); luaC_barriert(L, h, ra); VM_NEXT(); diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index cdb276c06..2472cd902 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -13,7 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Read, false) LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Force, false) // TODO: RAII deallocation doesn't work for longjmp builds if a memory error happens @@ -157,11 +156,12 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size return 1; } - if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : FFlag::LuauBytecodeV2Read ? (version != LBC_VERSION && version != LBC_VERSION_FUTURE) : (version != LBC_VERSION)) + if (FFlag::LuauBytecodeV2Force ? (version != LBC_VERSION_FUTURE) : (version != LBC_VERSION && version != LBC_VERSION_FUTURE)) { char chunkid[LUA_IDSIZE]; luaO_chunkid(chunkid, chunkname, LUA_IDSIZE); - lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); + lua_pushfstring(L, "%s: bytecode version mismatch (expected %d, got %d)", chunkid, + FFlag::LuauBytecodeV2Force ? LBC_VERSION_FUTURE : LBC_VERSION, version); return 1; } @@ -292,7 +292,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size p->p[j] = protos[fid]; } - if (FFlag::LuauBytecodeV2Force || (FFlag::LuauBytecodeV2Read && version == LBC_VERSION_FUTURE)) + if (FFlag::LuauBytecodeV2Force || version == LBC_VERSION_FUTURE) p->linedefined = readVarInt(data, size, offset); else p->linedefined = -1; diff --git a/extern/isocline/.gitignore b/extern/isocline/.gitignore new file mode 100644 index 000000000..470cc8137 --- /dev/null +++ b/extern/isocline/.gitignore @@ -0,0 +1,16 @@ +out/ +build/ +dist/ +doc/html/ +.vs/ +.vscode/ +.stack-work/ +.DS_Store +*.user +*.exe +*.hi +*.o +*_stub.h +*.lock +history.txt +isocline.debug.txt \ No newline at end of file diff --git a/extern/isocline/LICENSE b/extern/isocline/LICENSE new file mode 100644 index 000000000..7ac3104b1 --- /dev/null +++ b/extern/isocline/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 Daan Leijen + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/extern/isocline/include/isocline.h b/extern/isocline/include/isocline.h new file mode 100644 index 000000000..0d46cf3ff --- /dev/null +++ b/extern/isocline/include/isocline.h @@ -0,0 +1,627 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_ISOCLINE_H +#define IC_ISOCLINE_H + +#ifdef __cplusplus +extern "C" { +#endif + +#include // size_t +#include // bool +#include // uint32_t +#include // term_vprintf + + +/*! \mainpage +Isocline C API reference. + +Isocline is a pure C library that can be used as an alternative to the GNU readline library. + +See the [Github repository](https://github.com/daanx/isocline#readme) +for general information and building the library. + +Contents: +- \ref readline +- \ref bbcode +- \ref history +- \ref completion +- \ref highlight +- \ref options +- \ref helper +- \ref completex +- \ref term +- \ref async +- \ref alloc +*/ + +/// \defgroup readline Readline +/// The basic readline interface. +/// \{ + +/// Isocline version: 102 = 1.0.2. +#define IC_VERSION (104) + + +/// Read input from the user using rich editing abilities. +/// @param prompt_text The prompt text, can be NULL for the default (""). +/// The displayed prompt becomes `prompt_text` followed by the `prompt_marker` ("> "). +/// @returns the heap allocated input on succes, which should be `free`d by the caller. +/// Returns NULL on error, or if the user typed ctrl+d or ctrl+c. +/// +/// If the standard input (`stdin`) has no editing capability +/// (like a dumb terminal (e.g. `TERM`=`dumb`), running in a debuggen, a pipe or redirected file, etc.) +/// the input is read directly from the input stream up to the +/// next line without editing capability. +/// See also \a ic_set_prompt_marker(), \a ic_style_def() +/// +/// @see ic_set_prompt_marker(), ic_style_def() +char* ic_readline(const char* prompt_text); + +/// \} + + +//-------------------------------------------------------------- +/// \defgroup bbcode Formatted Text +/// Formatted text using [bbcode markup](https://github.com/daanx/isocline#bbcode-format). +/// \{ + +/// Print to the terminal while respection bbcode markup. +/// Any unclosed tags are closed automatically at the end of the print. +/// For example: +/// ``` +/// ic_print("[b]bold, [i]bold and italic[/i], [red]red and bold[/][/b] default."); +/// ic_print("[b]bold[/], [i b]bold and italic[/], [yellow on blue]yellow on blue background"); +/// ic_style_add("em","i color=#888800"); +/// ic_print("[em]emphasis"); +/// ``` +/// Properties that can be assigned are: +/// * `color=` _clr_, `bgcolor=` _clr_: where _clr_ is either a hex value `#`RRGGBB or `#`RGB, a +/// standard HTML color name, or an ANSI palette name, like `ansi-maroon`, `ansi-default`, etc. +/// * `bold`,`italic`,`reverse`,`underline`: can be `on` or `off`. +/// * everything else is a style; all HTML and ANSI color names are also a style (so we can just use `red` +/// instead of `color=red`, or `on red` instead of `bgcolor=red`), and there are +/// the `b`, `i`, `u`, and `r` styles for bold, italic, underline, and reverse. +/// +/// See [here](https://github.com/daanx/isocline#bbcode-format) for a description of the full bbcode format. +void ic_print( const char* s ); + +/// Print with bbcode markup ending with a newline. +/// @see ic_print() +void ic_println( const char* s ); + +/// Print formatted with bbcode markup. +/// @see ic_print() +void ic_printf(const char* fmt, ...); + +/// Print formatted with bbcode markup. +/// @see ic_print +void ic_vprintf(const char* fmt, va_list args); + +/// Define or redefine a style. +/// @param style_name The name of the style. +/// @param fmt The `fmt` string is the content of a tag and can contain +/// other styles. This is very useful to theme the output of a program +/// by assigning standard styles like `em` or `warning` etc. +void ic_style_def( const char* style_name, const char* fmt ); + +/// Start a global style that is only reset when calling a matching ic_style_close(). +void ic_style_open( const char* fmt ); + +/// End a global style. +void ic_style_close(void); + +/// \} + + +//-------------------------------------------------------------- +// History +//-------------------------------------------------------------- +/// \defgroup history History +/// Readline input history. +/// \{ + +/// Enable history. +/// Use a \a NULL filename to not persist the history. Use -1 for max_entries to get the default (200). +void ic_set_history(const char* fname, long max_entries ); + +/// Remove the last entry in the history. +/// The last returned input from ic_readline() is automatically added to the history; this function removes it. +void ic_history_remove_last(void); + +/// Clear the history. +void ic_history_clear(void); + +/// Add an entry to the history +void ic_history_add( const char* entry ); + +/// \} + +//-------------------------------------------------------------- +// Basic Completion +//-------------------------------------------------------------- + +/// \defgroup completion Completion +/// Basic word completion. +/// \{ + +/// A completion environment +struct ic_completion_env_s; + +/// A completion environment +typedef struct ic_completion_env_s ic_completion_env_t; + +/// A completion callback that is called by isocline when tab is pressed. +/// It is passed a completion environment (containing the current input and the current cursor position), +/// the current input up-to the cursor (`prefix`) +/// and the user given argument when the callback was set. +/// When using completion transformers, like `ic_complete_quoted_word` the `prefix` contains the +/// the word to be completed without escape characters or quotes. +typedef void (ic_completer_fun_t)(ic_completion_env_t* cenv, const char* prefix ); + +/// Set the default completion handler. +/// @param completer The completion function +/// @param arg Argument passed to the \a completer. +/// There can only be one default completion function, setting it again disables the previous one. +/// The initial completer use `ic_complete_filename`. +void ic_set_default_completer( ic_completer_fun_t* completer, void* arg); + + +/// In a completion callback (usually from ic_complete_word()), use this function to add a completion. +/// (the completion string is copied by isocline and do not need to be preserved or allocated). +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completion(ic_completion_env_t* cenv, const char* completion); + +/// In a completion callback (usually from ic_complete_word()), use this function to add a completion. +/// The `display` is used to display the completion in the completion menu, and `help` is +/// displayed for hints for example. Both can be `NULL` for the default. +/// (all are copied by isocline and do not need to be preserved or allocated). +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completion_ex( ic_completion_env_t* cenv, const char* completion, const char* display, const char* help ); + +/// In a completion callback (usually from ic_complete_word()), use this function to add completions. +/// The `completions` array should be terminated with a NULL element, and all elements +/// are added as completions if they start with `prefix`. +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completions(ic_completion_env_t* cenv, const char* prefix, const char** completions); + +/// Complete a filename. +/// Complete a filename given a semi-colon separated list of root directories `roots` and +/// semi-colon separated list of possible extensions (excluding directories). +/// If `roots` is NULL, the current directory is the root ("."). +/// If `extensions` is NULL, any extension will match. +/// Each root directory should _not_ end with a directory separator. +/// If a directory is completed, the `dir_separator` is added at the end if it is not `0`. +/// Usually the `dir_separator` is `/` but it can be set to `\\` on Windows systems. +/// For example: +/// ``` +/// /ho --> /home/ +/// /home/.ba --> /home/.bashrc +/// ``` +/// (This already uses ic_complete_quoted_word() so do not call it from inside a word handler). +void ic_complete_filename( ic_completion_env_t* cenv, const char* prefix, char dir_separator, const char* roots, const char* extensions ); + + + +/// Function that returns whether a (utf8) character (of length `len`) is in a certain character class +/// @see ic_char_is_separator() etc. +typedef bool (ic_is_char_class_fun_t)(const char* s, long len); + + +/// Complete a _word_ (i.e. _token_). +/// Calls the user provided function `fun` to complete on the +/// current _word_. Almost all user provided completers should use this function. +/// If `is_word_char` is NULL, the default `&ic_char_is_nonseparator` is used. +/// The `prefix` passed to `fun` is modified to only contain the current word, and +/// any results from `ic_add_completion` are automatically adjusted to replace that part. +/// For example, on the input "hello w", a the user `fun` only gets `w` and can just complete +/// with "world" resulting in "hello world" without needing to consider `delete_before` etc. +/// @see ic_complete_qword() for completing quoted and escaped tokens. +void ic_complete_word(ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char); + + +/// Complete a quoted _word_. +/// Calls the user provided function `fun` to complete while taking +/// care of quotes and escape characters. Almost all user provided completers should use +/// this function. The `prefix` passed to `fun` is modified to be unquoted and unescaped, and +/// any results from `ic_add_completion` are automatically quoted and escaped again. +/// For example, completing `hello world`, the `fun` always just completes `hel` or `hello w` to `hello world`, +/// but depending on user input, it will complete as: +/// ``` +/// hel --> hello\ world +/// hello\ w --> hello\ world +/// hello w --> # no completion, the word is just 'w'> +/// "hel --> "hello world" +/// "hello w --> "hello world" +/// ``` +/// with proper quotes and escapes. +/// If `is_word_char` is NULL, the default `&ic_char_is_nonseparator` is used. +/// @see ic_complete_quoted_word() to customize the word boundary, quotes etc. +void ic_complete_qword( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char ); + + + +/// Complete a _word_. +/// Calls the user provided function `fun` to complete while taking +/// care of quotes and escape characters. Almost all user provided completers should use this function. +/// The `is_word_char` is a set of characters that are part of a "word". Use NULL for the default (`&ic_char_is_nonseparator`). +/// The `escape_char` is the escaping character, usually `\` but use 0 to not have escape characters. +/// The `quote_chars` define the quotes, use NULL for the default `"\'\""` quotes. +/// @see ic_complete_word() which uses the default values for `non_word_chars`, `quote_chars` and `\` for escape characters. +void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t fun, + ic_is_char_class_fun_t* is_word_char, char escape_char, const char* quote_chars ); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup highlight Syntax Highlighting +/// Basic syntax highlighting. +/// \{ + +/// A syntax highlight environment +struct ic_highlight_env_s; +typedef struct ic_highlight_env_s ic_highlight_env_t; + +/// A syntax highlighter callback that is called by readline to syntax highlight user input. +typedef void (ic_highlight_fun_t)(ic_highlight_env_t* henv, const char* input, void* arg); + +/// Set a syntax highlighter. +/// There can only be one highlight function, setting it again disables the previous one. +void ic_set_default_highlighter(ic_highlight_fun_t* highlighter, void* arg); + +/// Set the style of characters starting at position `pos`. +void ic_highlight(ic_highlight_env_t* henv, long pos, long count, const char* style ); + +/// Experimental: Convenience callback for a function that highlights `s` using bbcode's. +/// The returned string should be allocated and is free'd by the caller. +typedef char* (ic_highlight_format_fun_t)(const char* s, void* arg); + +/// Experimental: Convenience function for highlighting with bbcodes. +/// Can be called in a `ic_highlight_fun_t` callback to colorize the `input` using the +/// the provided `formatted` input that is the styled `input` with bbcodes. The +/// content of `formatted` without bbcode tags should match `input` exactly. +void ic_highlight_formatted(ic_highlight_env_t* henv, const char* input, const char* formatted); + +/// \} + +//-------------------------------------------------------------- +// Readline with a specific completer and highlighter +//-------------------------------------------------------------- + +/// \defgroup readline +/// \{ + +/// Read input from the user using rich editing abilities, +/// using a particular completion function and highlighter for this call only. +/// both can be NULL in which case the defaults are used. +/// @see ic_readline(), ic_set_prompt_marker(), ic_set_default_completer(), ic_set_default_highlighter(). +char* ic_readline_ex(const char* prompt_text, ic_completer_fun_t* completer, void* completer_arg, + ic_highlight_fun_t* highlighter, void* highlighter_arg); + +/// \} + + +//-------------------------------------------------------------- +// Options +//-------------------------------------------------------------- + +/// \defgroup options Options +/// \{ + +/// Set a prompt marker and a potential marker for extra lines with multiline input. +/// Pass \a NULL for the `prompt_marker` for the default marker (`"> "`). +/// Pass \a NULL for continuation prompt marker to make it equal to the `prompt_marker`. +void ic_set_prompt_marker( const char* prompt_marker, const char* continuation_prompt_marker ); + +/// Get the current prompt marker. +const char* ic_get_prompt_marker(void); + +/// Get the current continuation prompt marker. +const char* ic_get_continuation_prompt_marker(void); + +/// Disable or enable multi-line input (enabled by default). +/// Returns the previous setting. +bool ic_enable_multiline( bool enable ); + +/// Disable or enable sound (enabled by default). +/// A beep is used when tab cannot find any completion for example. +/// Returns the previous setting. +bool ic_enable_beep( bool enable ); + +/// Disable or enable color output (enabled by default). +/// Returns the previous setting. +bool ic_enable_color( bool enable ); + +/// Disable or enable duplicate entries in the history (disabled by default). +/// Returns the previous setting. +bool ic_enable_history_duplicates( bool enable ); + +/// Disable or enable automatic tab completion after a completion +/// to expand as far as possible if the completions are unique. (disabled by default). +/// Returns the previous setting. +bool ic_enable_auto_tab( bool enable ); + +/// Disable or enable preview of a completion selection (enabled by default) +/// Returns the previous setting. +bool ic_enable_completion_preview( bool enable ); + +/// Disable or enable automatic identation of continuation lines in multiline +/// input so it aligns with the initial prompt. +/// Returns the previous setting. +bool ic_enable_multiline_indent(bool enable); + +/// Disable or enable display of short help messages for history search etc. +/// (full help is always dispayed when pressing F1 regardless of this setting) +/// @returns the previous setting. +bool ic_enable_inline_help(bool enable); + +/// Disable or enable hinting (enabled by default) +/// Shows a hint inline when there is a single possible completion. +/// @returns the previous setting. +bool ic_enable_hint(bool enable); + +/// Set millisecond delay before a hint is displayed. Can be zero. (500ms by default). +long ic_set_hint_delay(long delay_ms); + +/// Disable or enable syntax highlighting (enabled by default). +/// This applies regardless whether a syntax highlighter callback was set (`ic_set_highlighter`) +/// Returns the previous setting. +bool ic_enable_highlight(bool enable); + + +/// Set millisecond delay for reading escape sequences in order to distinguish +/// a lone ESC from the start of a escape sequence. The defaults are 100ms and 10ms, +/// but it may be increased if working with very slow terminals. +void ic_set_tty_esc_delay(long initial_delay_ms, long followup_delay_ms); + +/// Enable highlighting of matching braces (and error highlight unmatched braces).` +bool ic_enable_brace_matching(bool enable); + +/// Set matching brace pairs. +/// Pass \a NULL for the default `"()[]{}"`. +void ic_set_matching_braces(const char* brace_pairs); + +/// Enable automatic brace insertion (enabled by default). +bool ic_enable_brace_insertion(bool enable); + +/// Set matching brace pairs for automatic insertion. +/// Pass \a NULL for the default `()[]{}\"\"''` +void ic_set_insertion_braces(const char* brace_pairs); + +/// \} + + +//-------------------------------------------------------------- +// Advanced Completion +//-------------------------------------------------------------- + +/// \defgroup completex Advanced Completion +/// \{ + +/// Get the raw current input (and cursor position if `cursor` != NULL) for the completion. +/// Usually completer functions should look at their `prefix` though as transformers +/// like `ic_complete_word` may modify the prefix (for example, unescape it). +const char* ic_completion_input( ic_completion_env_t* cenv, long* cursor ); + +/// Get the completion argument passed to `ic_set_completer`. +void* ic_completion_arg( const ic_completion_env_t* cenv ); + +/// Do we have already some completions? +bool ic_has_completions( const ic_completion_env_t* cenv ); + +/// Do we already have enough completions and should we return if possible? (for improved latency) +bool ic_stop_completing( const ic_completion_env_t* cenv); + + +/// Primitive completion, cannot be used with most transformers (like `ic_complete_word` and `ic_complete_qword`). +/// When completed, `delete_before` _bytes_ are deleted before the cursor position, +/// `delete_after` _bytes_ are deleted after the cursor, and finally `completion` is inserted. +/// The `display` is used to display the completion in the completion menu, and `help` is displayed +/// with hinting. Both `display` and `help` can be NULL. +/// (all are copied by isocline and do not need to be preserved or allocated). +/// +/// Returns `true` if the callback should continue trying to find more possible completions. +/// If `false` is returned, the callback should try to return and not add more completions (for improved latency). +bool ic_add_completion_prim( ic_completion_env_t* cenv, const char* completion, + const char* display, const char* help, + long delete_before, long delete_after); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup helper Character Classes. +/// Convenience functions for character classes, highlighting and completion. +/// \{ + +/// Convenience: return the position of a previous code point in a UTF-8 string `s` from postion `pos`. +/// Returns `-1` if `pos <= 0` or `pos > strlen(s)` (or other errors). +long ic_prev_char( const char* s, long pos ); + +/// Convenience: return the position of the next code point in a UTF-8 string `s` from postion `pos`. +/// Returns `-1` if `pos < 0` or `pos >= strlen(s)` (or other errors). +long ic_next_char( const char* s, long pos ); + +/// Convenience: does a string `s` starts with a given `prefix` ? +bool ic_starts_with( const char* s, const char* prefix ); + +/// Convenience: does a string `s` starts with a given `prefix` ignoring (ascii) case? +bool ic_istarts_with( const char* s, const char* prefix ); + + +/// Convenience: character class for whitespace `[ \t\r\n]`. +bool ic_char_is_white(const char* s, long len); + +/// Convenience: character class for non-whitespace `[^ \t\r\n]`. +bool ic_char_is_nonwhite(const char* s, long len); + +/// Convenience: character class for separators. +/// (``[ \t\r\n,.;:/\\(){}\[\]]``.) +/// This is used for word boundaries in isocline. +bool ic_char_is_separator(const char* s, long len); + +/// Convenience: character class for non-separators. +bool ic_char_is_nonseparator(const char* s, long len); + +/// Convenience: character class for letters (`[A-Za-z]` and any unicode > 0x80). +bool ic_char_is_letter(const char* s, long len); + +/// Convenience: character class for digits (`[0-9]`). +bool ic_char_is_digit(const char* s, long len); + +/// Convenience: character class for hexadecimal digits (`[A-Fa-f0-9]`). +bool ic_char_is_hexdigit(const char* s, long len); + +/// Convenience: character class for identifier letters (`[A-Za-z0-9_-]` and any unicode > 0x80). +bool ic_char_is_idletter(const char* s, long len); + +/// Convenience: character class for filename letters (_not in_ " \t\r\n`@$><=;|&\{\}\(\)\[\]]"). +bool ic_char_is_filename_letter(const char* s, long len); + + +/// Convenience: If this is a token start, return the length. Otherwise return 0. +long ic_is_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char); + +/// Convenience: Does this match the specified token? +/// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +/// E.g. `ic_match_token("function",0,&ic_char_is_letter,"fun")` returns 0. +/// while `ic_match_token("fun x",0,&ic_char_is_letter,"fun"})` returns 3. +long ic_match_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char* token); + + +/// Convenience: Do any of the specified tokens match? +/// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +/// E.g. `ic_match_any_token("function",0,&ic_char_is_letter,{"fun","func",NULL})` returns 0. +/// while `ic_match_any_token("func x",0,&ic_char_is_letter,{"fun","func",NULL})` returns 4. +long ic_match_any_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char** tokens); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup term Terminal +/// +/// Experimental: Low level terminal output. +/// Ensures basic ANSI SGR escape sequences are processed +/// in a portable way (e.g. on Windows) +/// \{ + +/// Initialize for terminal output. +/// Call this before using the terminal write functions (`ic_term_write`) +/// Does nothing on most platforms but on Windows it sets the console to UTF8 output and possible +/// enables virtual terminal processing. +void ic_term_init(void); + +/// Call this when done with the terminal functions. +void ic_term_done(void); + +/// Flush the terminal output. +/// (happens automatically on newline characters ('\n') as well). +void ic_term_flush(void); + +/// Write a string to the console (and process CSI escape sequences). +void ic_term_write(const char* s); + +/// Write a string to the console and end with a newline +/// (and process CSI escape sequences). +void ic_term_writeln(const char* s); + +/// Write a formatted string to the console. +/// (and process CSI escape sequences) +void ic_term_writef(const char* fmt, ...); + +/// Write a formatted string to the console. +void ic_term_vwritef(const char* fmt, va_list args); + +/// Set text attributes from a style. +void ic_term_style( const char* style ); + +/// Set text attribute to bold. +void ic_term_bold(bool enable); + +/// Set text attribute to underline. +void ic_term_underline(bool enable); + +/// Set text attribute to italic. +void ic_term_italic(bool enable); + +/// Set text attribute to reverse video. +void ic_term_reverse(bool enable); + +/// Set text attribute to ansi color palette index between 0 and 255 (or 256 for the ANSI "default" color). +/// (auto matched to smaller palette if not supported) +void ic_term_color_ansi(bool foreground, int color); + +/// Set text attribute to 24-bit RGB color (between `0x000000` and `0xFFFFFF`). +/// (auto matched to smaller palette if not supported) +void ic_term_color_rgb(bool foreground, uint32_t color ); + +/// Reset the text attributes. +void ic_term_reset( void ); + +/// Get the palette used by the terminal: +/// This is usually initialized from the COLORTERM environment variable. The +/// possible values of COLORTERM for each palette are given in parenthesis. +/// +/// - 1: monochrome (`monochrome`) +/// - 3: old ANSI terminal with 8 colors, using bold for bright (`8color`/`3bit`) +/// - 4: regular ANSI terminal with 16 colors. (`16color`/`4bit`) +/// - 8: terminal with ANSI 256 color palette. (`256color`/`8bit`) +/// - 24: true-color terminal with full RGB colors. (`truecolor`/`24bit`/`direct`) +int ic_term_get_color_bits( void ); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup async ASync +/// Async support +/// \{ + +/// Thread-safe way to asynchronously unblock a readline. +/// Behaves as if the user pressed the `ctrl-C` character +/// (resulting in returning NULL from `ic_readline`). +/// Returns `true` if the event was successfully delivered. +/// (This may not be supported on all platforms, but it is +/// functional on Linux, macOS and Windows). +bool ic_async_stop(void); + +/// \} + +//-------------------------------------------------------------- +/// \defgroup alloc Custom Allocation +/// Register allocation functions for custom allocators +/// \{ + +typedef void* (ic_malloc_fun_t)( size_t size ); +typedef void* (ic_realloc_fun_t)( void* p, size_t newsize ); +typedef void (ic_free_fun_t)( void* p ); + +/// Initialize with custom allocation functions. +/// This must be called as the first function in a program! +void ic_init_custom_alloc( ic_malloc_fun_t* _malloc, ic_realloc_fun_t* _realloc, ic_free_fun_t* _free ); + +/// Free a potentially custom alloc'd pointer (in particular, the result returned from `ic_readline`) +void ic_free( void* p ); + +/// Allocate using the current memory allocator. +void* ic_malloc(size_t sz); + +/// Duplicate a string using the current memory allocator. +const char* ic_strdup( const char* s ); + +/// \} + +#ifdef __cplusplus +} +#endif + +#endif /// IC_ISOCLINE_H diff --git a/extern/isocline/readme.md b/extern/isocline/readme.md new file mode 100644 index 000000000..1f4709bdc --- /dev/null +++ b/extern/isocline/readme.md @@ -0,0 +1,460 @@ + + + + +# Isocline: a portable readline alternative. + +Isocline is a pure C library that can be used as an alternative to the GNU readline library (latest release v1.0.9, 2022-01-15). + +- Small: less than 8k lines and can be compiled as a single C file without + any dependencies or configuration (e.g. `gcc -c src/isocline.c`). + +- Portable: works on Unix, Windows, and macOS, and uses a minimal + subset of ANSI escape sequences. + +- Features: extensive multi-line editing mode (`shift-tab`), (24-bit) color, history, completion, unicode, + undo/redo, incremental history search, inline hints, syntax highlighting, brace matching, + closing brace insertion, auto indentation, graceful fallback, support for custom allocators, etc. + +- License: MIT. + +- Comes with a Haskell binding ([`System.Console.Isocline`][hdoc]. + +Enjoy, + Daan + + + +# Demo + +![recording](doc/record-macos.svg) + +Shows in order: unicode, syntax highlighting, brace matching, jump to matching brace, auto indent, multiline editing, 24-bit colors, inline hinting, filename completion, and incremental history search. +(screen capture was made with [termtosvg] by Nicolas Bedos) + +# Usage + +Include the isocline header in your C or C++ source: +```C +#include +``` + +and call `ic_readline` to get user input with rich editing abilities: +```C +char* input; +while( (input = ic_readline("prompt")) != NULL ) { // ctrl+d/c or errors return NULL + printf("you typed:\n%s\n", input); // use the input + free(input); +} +``` + +See the [example] for a full example with completion, syntax highligting, history, etc. + +# Run the Example + +You can compile and run the [example] as: +``` +$ gcc -o example -Iinclude test/example.c src/isocline.c +$ ./example +``` + +or, the Haskell [example][HaskellExample]: +``` +$ ghc -ihaskell test/Example.hs src/isocline.c +$ ./test/Example +``` + + +# Editing with Isocline + +Isocline tries to be as compatible as possible with standard [GNU Readline] key bindings. + +### Overview: +```apl + home/ctrl-a cursor end/ctrl-e + ┌─────────────────┼───────────────┐ (navigate) + │ ctrl-left │ ctrl-right │ + │ ┌───────┼──────┐ │ ctrl-r : search history + ▼ ▼ ▼ ▼ ▼ tab : complete word + prompt> it is the quintessential language shift-tab: insert new line + ▲ ▲ ▲ ▲ esc : delete input, done + │ └──────────────┘ │ ctrl-z : undo + │ alt-backsp alt-d │ + └─────────────────────────────────┘ (delete) + ctrl-u ctrl-k +``` + +Note: on macOS, the meta (alt) key is not directly available in most terminals. +Terminal/iTerm2 users can activate the meta key through +`Terminal` → `Preferences` → `Settings` → `Use option as meta key`. + +### Key Bindings + +These are also shown when pressing `F1` on a Isocline prompt. We use `^` as a shorthand for `ctrl-`: + +| Navigation | | +|-------------------|-------------------------------------------------| +| `left`,`^b` | go one character to the left | +| `right`,`^f ` | go one character to the right | +| `up ` | go one row up, or back in the history | +| `down ` | go one row down, or forward in the history | +| `^left ` | go to the start of the previous word | +| `^right ` | go to the end the current word | +| `home`,`^a ` | go to the start of the current line | +| `end`,`^e ` | go to the end of the current line | +| `pgup`,`^home ` | go to the start of the current input | +| `pgdn`,`^end ` | go to the end of the current input | +| `alt-m ` | jump to matching brace | +| `^p ` | go back in the history | +| `^n ` | go forward in the history | +| `^r`,`^s ` | search the history starting with the current word | + + +| Deletion | | +|-------------------|-------------------------------------------------| +| `del`,`^d ` | delete the current character | +| `backsp`,`^h ` | delete the previous character | +| `^w ` | delete to preceding white space | +| `alt-backsp ` | delete to the start of the current word | +| `alt-d ` | delete to the end of the current word | +| `^u ` | delete to the start of the current line | +| `^k ` | delete to the end of the current line | +| `esc ` | delete the current input, or done with empty input | + + +| Editing | | +|-------------------|-------------------------------------------------| +| `enter ` | accept current input | +| `^enter`,`^j`,`shift-tab` | create a new line for multi-line input | +| `^l ` | clear screen | +| `^t ` | swap with previous character (move character backward) | +| `^z`,`^_ ` | undo | +| `^y ` | redo | +| `tab ` | try to complete the current input | + + +| Completion menu | | +|-------------------|-------------------------------------------------| +| `enter`,`left` | use the currently selected completion | +| `1` - `9` | use completion N from the menu | +| `tab, down ` | select the next completion | +| `shift-tab, up` | select the previous completion | +| `esc ` | exit menu without completing | +| `pgdn`,`^enter`,`^j` | show all further possible completions | + + +| Incremental history search | | +|-------------------|-------------------------------------------------| +| `enter ` | use the currently found history entry | +| `backsp`,`^z ` | go back to the previous match (undo) | +| `tab`,`^r`,`up` | find the next match | +| `shift-tab`,`^s`,`down` | find an earlier match | +| `esc ` | exit search | + + +# Build the Library + +### Build as a Single Source + +Copy the sources (in `include` and `src`) into your project, or add the library as a [submodule]: +``` +$ git submodule add https://github.com/daanx/isocline +``` +and add `isocline/src/isocline.c` to your build rules -- no configuration is needed. + +### Build with CMake + +Clone the repository and run cmake to build a static library (`.a`/`.lib`): +``` +$ git clone https://github.com/daanx/isocline +$ cd isocline +$ mkdir -p build/release +$ cd build/release +$ cmake ../.. +$ cmake --build . +``` +This builds a static library `libisocline.a` (or `isocline.lib` on Windows) +and the example program: +``` +$ ./example +``` + +### Build the Haskell Library + +See the Haskell [readme][Haskell] for instructions to build and use the Haskell library. + + +# API Reference + +* See the [C API reference][docapi] and the [example] for example usage of history, completion, etc. + +* See the [Haskell API reference][hdoc] on Hackage and the Haskell [example][HaskellExample]. + + +# Motivation + +Isocline was created for use in the [Koka] interactive compiler. +This required: pure C (no dependency on a C++ runtime or other libraries), +portable (across Linux, macOS, and Windows), unicode support, +a BSD-style license, and good functionality for completion and multi-line editing. + +Some other excellent libraries that we considered: +[GNU readline], +[editline](https://github.com/troglobit/editline), +[linenoise](https://github.com/antirez/linenoise), +[replxx](https://github.com/AmokHuginnsson/replxx), and +[Haskeline](https://github.com/judah/haskeline). + + +# Formatted Output + +Isocline also exposes functions for rich terminal output +as `ic_print` (and `ic_println` and `ic_printf`). +Inspired by the (Python) [Rich][RichBBcode] library, +this supports a form of [bbcode]'s to format the output: +```c +ic_println( "[b]bold [red]and red[/red][/b]" ); +``` +Each print automatically closes any open tags that were +not yet closed. Also, you can use a general close +tag as `[/]` to close the innermost tag, so the +following print is equivalent to the earlier one: +```c +ic_println( "[b]bold [red]and red[/]" ); +``` +There can be multiple styles in one tag +(where the first name is used for the closing tag): +```c +ic_println( "[u #FFD700]underlined gold[/]" ); +``` + +Sometimes, you need to display arbitrary messages +that may contain sequences that you would not like +to be interpreted as bbcode tags. One way to do +this is the `[!`_tag_`]` which ignores formatting +up to a close tag of the form `[/`_tag_`]`. +```c +ic_printf( "[red]red? [!pre]%s[/pre].\n", "[blue]not blue!" ); +``` + +Predefined styles include `b` (bold), +`u` (underline), `i` (italic), and `r` (reverse video), but +you can (re)define any style yourself as: +```c +ic_style_def("warning", "crimson u"); +``` + +and use them like any builtin style or property: +```c +ic_println( "[warning]this is a warning![/]" ); +``` +which is great for adding themes to your application. + +Each `ic_print` function always closes any unclosed tags automatically. +To open a style persistently, use `ic_style_open` with a matching +`ic_style_close` which scopes over any `ic_print` statements in between. +```c +ic_style_open("warning"); +ic_println("[b]crimson underlined and bold[/]"); +ic_style_close(); +``` + +# Advanced + + +## BBCode Format + +An open tag can have multiple white space separated +entries that are +either a _style name_, or a primitive _property_[`=`_value_]. + +### Styles + +Isocline provides the following builtin styles as property shorthands: +`b` (bold), `u` (underline), `i` (italic), `r` (reverse video), +and some builtin styles for syntax highlighting: +`keyword`, `control` (control-flow keywords), `string`, +`comment`, `number`, `type`, `constant`. + +Predefined styles used by Isocline itself are: + +- `ic-prompt`: prompt style, e.g. `ic_style_def("ic-prompt", "yellow on blue")`. +- `ic-info`: information (like the numbers in a completion menu). +- `ic-diminish`: dim text (used for example in history search). +- `ic-emphasis`: emphasized text (also used in history search). +- `ic-hint`: color of an inline hint. +- `ic-error`: error color (like an unmatched brace). +- `ic-bracematch`: color of matching parenthesis. + +### Properties + +Boolean properties are by default `on`: + +- `bold` [`=`(`on`|`off`)] +- `italic` [`=`(`on`|`off`)] +- `underline` [`=`(`on`|`off`)] +- `reverse` [`=`(`on`|`off`)] + +Color properties can be assigned a _color_: + +- `color=`_color_ +- `bgcolor=`_color_ +- _color_: equivalent to `color=`_color_. +- `on` _color_: equivalent to `bgcolor=`_color_. + +A color value can be specified in many ways: + +- any standard HTML [color name][htmlcolors]. +- any of the 16 standard ANSI [color names][ansicolors] by prefixing `ansi-` + (like `ansi-black` or `ansi-maroon`). + The actual color value of these depend on the a terminal theme. +- `#`_rrggbb_ or `#`_rgb_ for a specific 24-bit color. +- `ansi-color=`_idx_: where 0 <= _idx_ <= 256 specifies an entry in the + standard ANSI 256 [color palette][ansicolor256], where 256 is used for the ANSI + default color. + + +## Environment Variables + +- `NO_COLOR`: if present no colors are displayed. +- `CLICOLOR=1`: if set, the `LSCOLORS` or `LS_COLORS` environment variables are used to colorize + filename completions. +- `COLORTERM=`(`truecolor`|`256color`|`16color`|`8color`|`monochrome`): enable a certain color palette, see the next section. +- `TERM`: used on some systems to determine the color + +## Colors + +Isocline supports 24-bit colors and any RGB colors are automatically +mapped to a reduced palette on older terminals if these do not +support true color. Detection of full color support +is not always possible to do automatically and you can +set the `COLORTERM` environment variable expicitly to force Isocline to use +a specific palette: +- `COLORTERM=truecolor`: use 24-bit colors. + +- `COLORTERM=256color`: use the ANSI 256 color palette. + +- `COLORTERM=16color` : use the regular ANSI 16 color + palette (8 normal and 8 bright colors). + +- `COLORTERM=8color`: use bold for bright colors. +- `COLORTERM=monochrome`: use no color. + +The above screenshots are made with the +[`test_colors.c`](https://github.com/daanx/isocline/blob/main/test/test_colors.c) program. You can test your own +terminal as: +``` +$ gcc -o test_colors -Iinclude test/test_colors.c src/isocline.c +$ ./test_colors +$ COLORTERM=truecolor ./test_colors +$ COLORTERM=16color ./test_colors +``` + +## ANSI Escape Sequences + +Isocline uses just few ANSI escape sequences that are widely +supported: +- `ESC[`_n_`A`, `ESC[`_n_`B`, `ESC[`_n_`C`, and `ESC[`_n_`D`, + for moving the cursor _n_ places up, down, right, and left. +- `ESC[K` to clear the line from the cursor. +- `ESC[`_n_`m` for colors, with _n_ one of: 0 (reset), 1,22 (bold), 3,23 (italic), + 4,24 (underline), 7,27 (reverse), 30-37,40-47,90-97,100-107 (color), + and 39,49 (select default color). +- `ESC[38;5;`_n_`m`, `ESC[48;5;`_n_`m`, `ESC[38;2;`_r_`;`_g_`;`_b_`m`, `ESC[48;2;`_r_`;`_g_`;`_b_`m`: + on terminals that support it, select + entry _n_ from the + 256 color ANSI palette (used with `XTERM=xterm-256color` for example), or directly specify + any 24-bit _rgb_ color (used with `COLORTERM=truecolor`) for the foreground or background. + +On Windows the above functionality is implemented using the Windows console API +(except if running in the new Windows Terminal which supports these escape +sequences natively). + +## Async and Threads + +Isocline is _not_ thread-safe and `ic_readline`_xxx_ and `ic_print`_xxx_ should +be used from one thread only. + +The best way to use `ic_readline` asynchronously is +to run it in a (blocking) dedicated thread and deliver +results from there to the async event loop. Isocline has the +```C +bool ic_async_stop(void) +``` +function that is thread-safe and can deliver an +asynchronous event to Isocline that unblocks a current +`ic_readline` and makes it behave as if the user pressed +`ctrl-c` (which returns NULL from the read line call). + +## Color Mapping + +To map full RGB colors to an ANSI 256 or 16-color palette +Isocline finds a palette color with the minimal "color distance" to +the original color. There are various +ways of calculating this: one way is to take the euclidean distance +in the sRGB space (_simple-rgb_), a slightly better way is to +take a weighted distance where the weight distribution is adjusted +according to how big the red component is ([redmean](https://en.wikipedia.org/wiki/Color_difference), +denoted as _delta-rgb_ in the figure), +this is used by Isocline), +and finally, we can first translate into a perceptually uniform color space +(CIElab) and calculate the distance there using the [CIEDE2000](https://en.wikipedia.org/wiki/Color_difference) +algorithm (_ciede2000_). Here are these three methods compared on +some colors: + +![color space comparison](doc/color/colorspace-map.png) + +Each top row is the true 24-bit RGB color. Surprisingly, +the sophisticated CIEDE2000 distance seems less good here compared to the +simpler methods (as in the upper left block for example) +(perhaps because this algorithm was created to find close +perceptual colors in images where lightness differences may be given +less weight?). CIEDE2000 also leads to more "outliers", for example as seen +in column 5. Given these results, Isocline uses _redmean_ for +color mapping. We also add a gray correction that makes it less +likely to substitute a color for a gray value (and the other way +around). + + +## Possible Future Extensions + +- Vi key bindings. +- kill buffer. +- make the `ic_print`_xxx_ functions thread-safe. +- extended low-level terminal functions. +- status and progress bars. +- prompt variants: confirm, etc. +- ... + +Contact me if you are interested in doing any of these :-) + + +# Releases + +* `2022-01-15`: v1.0.9: fix missing `ic_completion_arg` (issue #6), + fix null ptr check in ic_print (issue #7), fix crash when using /dev/null as both input and output. +* `2021-09-05`: v1.0.5: use our own wcwidth for consistency; + thanks to Hans-Georg Breunig for helping with testing on NetBSD. +* `2021-08-28`: v1.0.4: fix color query on Ubuntu/Gnome +* `2021-08-27`: v1.0.3: fix duplicates in completions +* `2021-08-23`: v1.0.2: fix windows eol wrapping +* `2021-08-21`: v1.0.1: fix line-buffering +* `2021-08-20`: v1.0.0: initial release + + + +[GNU readline]: https://tiswww.case.edu/php/chet/readline/rltop.html +[koka]: http://www.koka-lang.org +[submodule]: https://git-scm.com/book/en/v2/Git-Tools-Submodules +[Haskell]: https://github.com/daanx/isocline/tree/main/haskell +[HaskellExample]: https://github.com/daanx/isocline/blob/main/test/Example.hs +[example]: https://github.com/daanx/isocline/blob/main/test/example.c +[termtosvg]: https://github.com/nbedos/termtosvg +[Rich]: https://github.com/willmcgugan/rich +[RichBBcode]: https://rich.readthedocs.io/en/latest/markup.html +[bbcode]: https://en.wikipedia.org/wiki/BBCode +[htmlcolors]: https://en.wikipedia.org/wiki/Web_colors#HTML_color_names +[ansicolors]: https://en.wikipedia.org/wiki/Web_colors#Basic_colors +[ansicolor256]: https://en.wikipedia.org/wiki/ANSI_escape_code#8-bit +[docapi]: https://daanx.github.io/isocline +[hdoc]: https://hackage.haskell.org/package/isocline/docs/System-Console-Isocline.html diff --git a/extern/isocline/src/attr.c b/extern/isocline/src/attr.c new file mode 100644 index 000000000..b5ad78f8a --- /dev/null +++ b/extern/isocline/src/attr.c @@ -0,0 +1,294 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include + +#include "common.h" +#include "stringbuf.h" // str_next_ofs +#include "attr.h" +#include "term.h" // color_from_ansi256 + +//------------------------------------------------------------- +// Attributes +//------------------------------------------------------------- + +ic_private attr_t attr_none(void) { + attr_t attr; + attr.value = 0; + return attr; +} + +ic_private attr_t attr_default(void) { + attr_t attr = attr_none(); + attr.x.color = IC_ANSI_DEFAULT; + attr.x.bgcolor = IC_ANSI_DEFAULT; + attr.x.bold = IC_OFF; + attr.x.underline = IC_OFF; + attr.x.reverse = IC_OFF; + attr.x.italic = IC_OFF; + return attr; +} + +ic_private bool attr_is_none(attr_t attr) { + return (attr.value == 0); +} + +ic_private bool attr_is_eq(attr_t attr1, attr_t attr2) { + return (attr1.value == attr2.value); +} + +ic_private attr_t attr_from_color( ic_color_t color ) { + attr_t attr = attr_none(); + attr.x.color = color; + return attr; +} + + +ic_private attr_t attr_update_with( attr_t oldattr, attr_t newattr ) { + attr_t attr = oldattr; + if (newattr.x.color != IC_COLOR_NONE) { attr.x.color = newattr.x.color; } + if (newattr.x.bgcolor != IC_COLOR_NONE) { attr.x.bgcolor = newattr.x.bgcolor; } + if (newattr.x.bold != IC_NONE) { attr.x.bold = newattr.x.bold; } + if (newattr.x.italic != IC_NONE) { attr.x.italic = newattr.x.italic; } + if (newattr.x.reverse != IC_NONE) { attr.x.reverse = newattr.x.reverse; } + if (newattr.x.underline != IC_NONE) { attr.x.underline = newattr.x.underline; } + return attr; +} + +static bool sgr_is_digit(char c) { + return (c >= '0' && c <= '9'); +} + +static bool sgr_is_sep( char c ) { + return (c==';' || c==':'); +} + +static bool sgr_next_par(const char* s, ssize_t* pi, ssize_t* par) { + const ssize_t i = *pi; + ssize_t n = 0; + while( sgr_is_digit(s[i+n])) { + n++; + } + if (n==0) { + *par = 0; + return true; + } + else { + *pi = i+n; + return ic_atoz(s+i, par); + } +} + +static bool sgr_next_par3(const char* s, ssize_t* pi, ssize_t* p1, ssize_t* p2, ssize_t* p3) { + bool ok = false; + ssize_t i = *pi; + if (sgr_next_par(s,&i,p1) && sgr_is_sep(s[i])) { + i++; + if (sgr_next_par(s,&i,p2) && sgr_is_sep(s[i])) { + i++; + if (sgr_next_par(s,&i,p3)) { + ok = true; + }; + } + } + *pi = i; + return ok; +} + +ic_private attr_t attr_from_sgr( const char* s, ssize_t len) { + attr_t attr = attr_none(); + for( ssize_t i = 0; i < len && s[i] != 0; i++) { + ssize_t cmd = 0; + if (!sgr_next_par(s,&i,&cmd)) continue; + switch(cmd) { + case 0: attr = attr_default(); break; + case 1: attr.x.bold = IC_ON; break; + case 3: attr.x.italic = IC_ON; break; + case 4: attr.x.underline = IC_ON; break; + case 7: attr.x.reverse = IC_ON; break; + case 22: attr.x.bold = IC_OFF; break; + case 23: attr.x.italic = IC_OFF; break; + case 24: attr.x.underline = IC_OFF; break; + case 27: attr.x.reverse = IC_OFF; break; + case 39: attr.x.color = IC_ANSI_DEFAULT; break; + case 49: attr.x.bgcolor = IC_ANSI_DEFAULT; break; + default: { + if (cmd >= 30 && cmd <= 37) { + attr.x.color = IC_ANSI_BLACK + (unsigned)(cmd - 30); + } + else if (cmd >= 40 && cmd <= 47) { + attr.x.bgcolor = IC_ANSI_BLACK + (unsigned)(cmd - 40); + } + else if (cmd >= 90 && cmd <= 97) { + attr.x.color = IC_ANSI_DARKGRAY + (unsigned)(cmd - 90); + } + else if (cmd >= 100 && cmd <= 107) { + attr.x.bgcolor = IC_ANSI_DARKGRAY + (unsigned)(cmd - 100); + } + else if ((cmd == 38 || cmd == 48) && sgr_is_sep(s[i])) { + // non-associative SGR :-( + ssize_t par = 0; + i++; + if (sgr_next_par(s, &i, &par)) { + if (par==5 && sgr_is_sep(s[i])) { + // ansi 256 index + i++; + if (sgr_next_par(s, &i, &par) && par >= 0 && par <= 0xFF) { + ic_color_t color = color_from_ansi256(par); + if (cmd==38) { attr.x.color = color; } + else { attr.x.bgcolor = color; } + } + } + else if (par == 2 && sgr_is_sep(s[i])) { + // rgb value + i++; + ssize_t r,g,b; + if (sgr_next_par3(s, &i, &r,&g,&b)) { + ic_color_t color = ic_rgbx(r,g,b); + if (cmd==38) { attr.x.color = color; } + else { attr.x.bgcolor = color; } + } + } + } + } + else { + debug_msg("attr: unknow ANSI SGR code: %zd\n", cmd ); + } + } + } + } + return attr; +} + +ic_private attr_t attr_from_esc_sgr( const char* s, ssize_t len) { + if (len <= 2 || s[0] != '\x1B' || s[1] != '[' || s[len-1] != 'm') return attr_none(); + return attr_from_sgr(s+2, len-2); +} + + +//------------------------------------------------------------- +// Attribute buffer +//------------------------------------------------------------- +struct attrbuf_s { + attr_t* attrs; + ssize_t capacity; + ssize_t count; + alloc_t* mem; +}; + +static bool attrbuf_ensure_capacity( attrbuf_t* ab, ssize_t needed ) { + if (needed <= ab->capacity) return true; + ssize_t newcap = (ab->capacity <= 0 ? 240 : (ab->capacity > 1000 ? ab->capacity + 1000 : 2*ab->capacity)); + if (needed > newcap) { newcap = needed; } + attr_t* newattrs = mem_realloc_tp( ab->mem, attr_t, ab->attrs, newcap ); + if (newattrs == NULL) return false; + ab->attrs = newattrs; + ab->capacity = newcap; + assert(needed <= ab->capacity); + return true; +} + +static bool attrbuf_ensure_extra( attrbuf_t* ab, ssize_t extra ) { + const ssize_t needed = ab->count + extra; + return attrbuf_ensure_capacity( ab, needed ); +} + + +ic_private attrbuf_t* attrbuf_new( alloc_t* mem ) { + attrbuf_t* ab = mem_zalloc_tp(mem,attrbuf_t); + if (ab == NULL) return NULL; + ab->mem = mem; + attrbuf_ensure_extra(ab,1); + return ab; +} + +ic_private void attrbuf_free( attrbuf_t* ab ) { + if (ab==NULL) return; + mem_free(ab->mem, ab->attrs); + mem_free(ab->mem, ab); +} + +ic_private void attrbuf_clear(attrbuf_t* ab) { + if (ab == NULL) return; + ab->count = 0; +} + +ic_private ssize_t attrbuf_len( attrbuf_t* ab ) { + return (ab==NULL ? 0 : ab->count); +} + +ic_private const attr_t* attrbuf_attrs( attrbuf_t* ab, ssize_t expected_len ) { + assert(expected_len <= ab->count ); + // expand if needed + if (ab->count < expected_len) { + if (!attrbuf_ensure_capacity(ab,expected_len)) return NULL; + for(ssize_t i = ab->count; i < expected_len; i++) { + ab->attrs[i] = attr_none(); + } + ab->count = expected_len; + } + return ab->attrs; +} + + + +static void attrbuf_update_set_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr, bool update ) { + const ssize_t end = pos + count; + if (!attrbuf_ensure_capacity(ab, end)) return; + ssize_t i; + // initialize if end is beyond the count (todo: avoid duplicate init and set if update==false?) + if (ab->count < end) { + for(i = ab->count; i < end; i++) { + ab->attrs[i] = attr_none(); + } + ab->count = end; + } + // fill pos to end with attr + for(i = pos; i < end; i++) { + ab->attrs[i] = (update ? attr_update_with(ab->attrs[i],attr) : attr); + } +} + +ic_private void attrbuf_set_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ) { + attrbuf_update_set_at(ab, pos, count, attr, false); +} + +ic_private void attrbuf_update_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ) { + attrbuf_update_set_at(ab, pos, count, attr, true); +} + +ic_private void attrbuf_insert_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ) { + if (pos < 0 || pos > ab->count || count <= 0) return; + if (!attrbuf_ensure_extra(ab,count)) return; + ic_memmove( ab->attrs + pos + count, ab->attrs + pos, (ab->count - pos)*ssizeof(attr_t) ); + ab->count += count; + attrbuf_set_at( ab, pos, count, attr ); +} + + +// note: must allow ab == NULL! +ic_private ssize_t attrbuf_append_n( stringbuf_t* sb, attrbuf_t* ab, const char* s, ssize_t len, attr_t attr ) { + if (s == NULL || len == 0) return sbuf_len(sb); + if (ab != NULL) { + if (!attrbuf_ensure_extra(ab,len)) return sbuf_len(sb); + attrbuf_set_at(ab, ab->count, len, attr); + } + return sbuf_append_n(sb,s,len); +} + +ic_private attr_t attrbuf_attr_at( attrbuf_t* ab, ssize_t pos ) { + if (ab==NULL || pos < 0 || pos > ab->count) return attr_none(); + return ab->attrs[pos]; +} + +ic_private void attrbuf_delete_at( attrbuf_t* ab, ssize_t pos, ssize_t count ) { + if (ab==NULL || pos < 0 || pos > ab->count) return; + if (pos + count > ab->count) { count = ab->count - pos; } + if (count == 0) return; + assert(pos + count <= ab->count); + ic_memmove( ab->attrs + pos, ab->attrs + pos + count, ab->count - (pos + count) ); + ab->count -= count; +} diff --git a/extern/isocline/src/attr.h b/extern/isocline/src/attr.h new file mode 100644 index 000000000..8f37d0500 --- /dev/null +++ b/extern/isocline/src/attr.h @@ -0,0 +1,70 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_ATTR_H +#define IC_ATTR_H + +#include "common.h" +#include "stringbuf.h" + +//------------------------------------------------------------- +// text attributes +//------------------------------------------------------------- + +#define IC_ON (1) +#define IC_OFF (-1) +#define IC_NONE (0) + +// try to fit in 64 bits +// note: order is important for some compilers +// note: each color can actually be 25 bits +typedef union attr_s { + struct { + unsigned int color:28; + signed int bold:2; + signed int reverse:2; + unsigned int bgcolor:28; + signed int underline:2; + signed int italic:2; + } x; + uint64_t value; +} attr_t; + +ic_private attr_t attr_none(void); +ic_private attr_t attr_default(void); +ic_private attr_t attr_from_color( ic_color_t color ); + +ic_private bool attr_is_none(attr_t attr); +ic_private bool attr_is_eq(attr_t attr1, attr_t attr2); + +ic_private attr_t attr_update_with( attr_t attr, attr_t newattr ); + +ic_private attr_t attr_from_sgr( const char* s, ssize_t len); +ic_private attr_t attr_from_esc_sgr( const char* s, ssize_t len); + +//------------------------------------------------------------- +// attribute buffer used for rich rendering +//------------------------------------------------------------- + +struct attrbuf_s; +typedef struct attrbuf_s attrbuf_t; + +ic_private attrbuf_t* attrbuf_new( alloc_t* mem ); +ic_private void attrbuf_free( attrbuf_t* ab ); // ab can be NULL +ic_private void attrbuf_clear( attrbuf_t* ab ); // ab can be NULL +ic_private ssize_t attrbuf_len( attrbuf_t* ab); // ab can be NULL +ic_private const attr_t* attrbuf_attrs( attrbuf_t* ab, ssize_t expected_len ); +ic_private ssize_t attrbuf_append_n( stringbuf_t* sb, attrbuf_t* ab, const char* s, ssize_t len, attr_t attr ); + +ic_private void attrbuf_set_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ); +ic_private void attrbuf_update_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ); +ic_private void attrbuf_insert_at( attrbuf_t* ab, ssize_t pos, ssize_t count, attr_t attr ); + +ic_private attr_t attrbuf_attr_at( attrbuf_t* ab, ssize_t pos ); +ic_private void attrbuf_delete_at( attrbuf_t* ab, ssize_t pos, ssize_t count ); + +#endif // IC_ATTR_H diff --git a/extern/isocline/src/bbcode.c b/extern/isocline/src/bbcode.c new file mode 100644 index 000000000..4d11ac380 --- /dev/null +++ b/extern/isocline/src/bbcode.c @@ -0,0 +1,842 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include +#include + +#include "common.h" +#include "attr.h" +#include "term.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// HTML color table +//------------------------------------------------------------- + +#include "bbcode_colors.c" + +//------------------------------------------------------------- +// Types +//------------------------------------------------------------- + +typedef struct style_s { + const char* name; // name of the style + attr_t attr; // attribute to apply +} style_t; + +typedef enum align_e { + IC_ALIGN_LEFT, + IC_ALIGN_CENTER, + IC_ALIGN_RIGHT +} align_t; + +typedef struct width_s { + ssize_t w; // > 0 + align_t align; + bool dots; // "..." (e.g. "sentence...") + char fill; // " " (e.g. "hello ") +} width_t; + +typedef struct tag_s { + const char* name; // tag open name + attr_t attr; // the saved attribute before applying the style + width_t width; // start sequence of at most "width" columns + ssize_t pos; // start position in the output (used for width restriction) +} tag_t; + + +static void tag_init(tag_t* tag) { + memset(tag,0,sizeof(*tag)); +} + +struct bbcode_s { + tag_t* tags; // stack of tags; one entry for each open tag + ssize_t tags_capacity; + ssize_t tags_nesting; + style_t* styles; // list of used defined styles + ssize_t styles_capacity; + ssize_t styles_count; + term_t* term; // terminal + alloc_t* mem; // allocator + // caches + stringbuf_t* out; // print buffer + attrbuf_t* out_attrs; + stringbuf_t* vout; // vprintf buffer +}; + + +//------------------------------------------------------------- +// Create, helpers +//------------------------------------------------------------- + +ic_private bbcode_t* bbcode_new( alloc_t* mem, term_t* term ) { + bbcode_t* bb = mem_zalloc_tp(mem,bbcode_t); + if (bb==NULL) return NULL; + bb->mem = mem; + bb->term = term; + bb->out = sbuf_new(mem); + bb->out_attrs = attrbuf_new(mem); + bb->vout = sbuf_new(mem); + return bb; +} + +ic_private void bbcode_free( bbcode_t* bb ) { + for(ssize_t i = 0; i < bb->styles_count; i++) { + mem_free(bb->mem, bb->styles[i].name); + } + mem_free(bb->mem, bb->tags); + mem_free(bb->mem, bb->styles); + sbuf_free(bb->vout); + sbuf_free(bb->out); + attrbuf_free(bb->out_attrs); + mem_free(bb->mem, bb); +} + +ic_private void bbcode_style_add( bbcode_t* bb, const char* style_name, attr_t attr ) { + if (bb->styles_count >= bb->styles_capacity) { + ssize_t newlen = bb->styles_capacity + 32; + style_t* p = mem_realloc_tp( bb->mem, style_t, bb->styles, newlen ); + if (p == NULL) return; + bb->styles = p; + bb->styles_capacity = newlen; + } + assert(bb->styles_count < bb->styles_capacity); + bb->styles[bb->styles_count].name = mem_strdup( bb->mem, style_name ); + bb->styles[bb->styles_count].attr = attr; + bb->styles_count++; +} + +static ssize_t bbcode_tag_push( bbcode_t* bb, const tag_t* tag ) { + if (bb->tags_nesting >= bb->tags_capacity) { + ssize_t newcap = bb->tags_capacity + 32; + tag_t* p = mem_realloc_tp( bb->mem, tag_t, bb->tags, newcap ); + if (p == NULL) return -1; + bb->tags = p; + bb->tags_capacity = newcap; + } + assert(bb->tags_nesting < bb->tags_capacity); + bb->tags[bb->tags_nesting] = *tag; + bb->tags_nesting++; + return (bb->tags_nesting-1); +} + +static void bbcode_tag_pop( bbcode_t* bb, tag_t* tag ) { + if (bb->tags_nesting <= 0) { + if (tag != NULL) { tag_init(tag); } + } + else { + bb->tags_nesting--; + if (tag != NULL) { + *tag = bb->tags[bb->tags_nesting]; + } + } +} + +//------------------------------------------------------------- +// Invalid parse/values/balance +//------------------------------------------------------------- + +static void bbcode_invalid(const char* fmt, ... ) { + if (getenv("ISOCLINE_BBCODE_DEBUG") != NULL) { + va_list args; + va_start(args,fmt); + vfprintf(stderr,fmt,args); + va_end(args); + } +} + +//------------------------------------------------------------- +// Set attributes +//------------------------------------------------------------- + + +static attr_t bbcode_open( bbcode_t* bb, ssize_t out_pos, const tag_t* tag, attr_t current ) { + // save current and set + tag_t cur; + tag_init(&cur); + cur.name = tag->name; + cur.attr = current; + cur.width = tag->width; + cur.pos = out_pos; + bbcode_tag_push(bb,&cur); + return attr_update_with( current, tag->attr ); +} + +static bool bbcode_close( bbcode_t* bb, ssize_t base, const char* name, tag_t* pprev ) { + // pop until match + while (bb->tags_nesting > base) { + tag_t prev; + bbcode_tag_pop(bb,&prev); + if (name==NULL || prev.name==NULL || ic_stricmp(prev.name,name) == 0) { + // matched + if (pprev != NULL) { *pprev = prev; } + return true; + } + else { + // unbalanced: we either continue popping or restore the tags depending if there is a matching open tag in our tags. + bool has_open_tag = false; + if (name != NULL) { + for( ssize_t i = bb->tags_nesting - 1; i > base; i--) { + if (bb->tags[i].name != NULL && ic_stricmp(bb->tags[i].name, name) == 0) { + has_open_tag = true; + break; + } + } + } + bbcode_invalid("bbcode: unbalanced tags: open [%s], close [/%s]\n", prev.name, name); + if (!has_open_tag) { + bbcode_tag_push( bb, &prev ); // restore the tags and ignore this close tag + break; + } + else { + // continue until we hit our open tag + } + } + } + if (pprev != NULL) { memset(pprev,0,sizeof(*pprev)); } + return false; +} + +//------------------------------------------------------------- +// Update attributes +//------------------------------------------------------------- + +static const char* attr_update_bool( const char* fname, signed int* field, const char* value ) { + if (value == NULL || value[0] == 0 || strcmp(value,"on") || strcmp(value,"true") || strcmp(value,"1")) { + *field = IC_ON; + } + else if (strcmp(value,"off") || strcmp(value,"false") || strcmp(value,"0")) { + *field = IC_OFF; + } + else { + bbcode_invalid("bbcode: invalid %s value: %s\n", fname, value ); + } + return fname; +} + +static const char* attr_update_color( const char* fname, ic_color_t* field, const char* value ) { + if (value == NULL || value[0] == 0 || strcmp(value,"none") == 0) { + *field = IC_COLOR_NONE; + return fname; + } + + // hex value + if (value[0] == '#') { + uint32_t rgb = 0; + if (sscanf(value,"#%x",&rgb) == 1) { + *field = ic_rgb(rgb); + } + else { + bbcode_invalid("bbcode: invalid color code: %s\n", value); + } + return fname; + } + + // search color names + ssize_t lo = 0; + ssize_t hi = IC_HTML_COLOR_COUNT-1; + while( lo <= hi ) { + ssize_t mid = (lo + hi) / 2; + style_color_t* info = &html_colors[mid]; + int cmp = strcmp(info->name,value); + if (cmp < 0) { + lo = mid+1; + } + else if (cmp > 0) { + hi = mid-1; + } + else { + *field = info->color; + return fname; + } + } + bbcode_invalid("bbcode: unknown %s: %s\n", fname, value); + *field = IC_COLOR_NONE; + return fname; +} + +static const char* attr_update_sgr( const char* fname, attr_t* attr, const char* value ) { + *attr = attr_update_with(*attr, attr_from_sgr(value, ic_strlen(value))); + return fname; +} + +static void attr_update_width( width_t* pwidth, char default_fill, const char* value ) { + // parse width value: ;;; + width_t width; + memset(&width, 0, sizeof(width)); + width.fill = default_fill; // use 0 for no-fill (as for max-width) + if (ic_atoz(value, &width.w)) { + ssize_t i = 0; + while( value[i] != ';' && value[i] != 0 ) { i++; } + if (value[i] == ';') { + i++; + ssize_t len = 0; + while( value[i+len] != ';' && value[i+len] != 0 ) { len++; } + if (len == 4 && ic_istarts_with(value+i,"left")) { + width.align = IC_ALIGN_LEFT; + } + else if (len == 5 && ic_istarts_with(value+i,"right")) { + width.align = IC_ALIGN_RIGHT; + } + else if (len == 6 && ic_istarts_with(value+i,"center")) { + width.align = IC_ALIGN_CENTER; + } + i += len; + if (value[i] == ';') { + i++; len = 0; + while( value[i+len] != ';' && value[i+len] != 0 ) { len++; } + if (len == 1) { width.fill = value[i]; } + i+= len; + if (value[i] == ';') { + i++; len = 0; + while( value[i+len] != ';' && value[i+len] != 0 ) { len++; } + if ((len == 2 && ic_istarts_with(value+i,"on")) || (len == 1 && value[i] == '1')) { width.dots = true; } + i += len; + } + } + } + } + else { + bbcode_invalid("bbcode: illegal width: %s\n", value); + } + *pwidth = width; +} + +static const char* attr_update_ansi_color( const char* fname, ic_color_t* color, const char* value ) { + ssize_t num = 0; + if (ic_atoz(value, &num) && num >= 0 && num <= 256) { + *color = color_from_ansi256(num); + } + return fname; +} + + +static const char* attr_update_property( tag_t* tag, const char* attr_name, const char* value ) { + const char* fname = NULL; + fname = "bold"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.bold = b; } + return fname; + } + fname = "italic"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.italic = b; } + return fname; + } + fname = "underline"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.underline = b; } + return fname; + } + fname = "reverse"; + if (strcmp(attr_name,fname) == 0) { + signed int b = IC_NONE; + attr_update_bool(fname,&b, value); + if (b != IC_NONE) { tag->attr.x.reverse = b; } + return fname; + } + fname = "color"; + if (strcmp(attr_name,fname) == 0) { + unsigned int color = IC_COLOR_NONE; + attr_update_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.color = color; } + return fname; + } + fname = "bgcolor"; + if (strcmp(attr_name,fname) == 0) { + unsigned int color = IC_COLOR_NONE; + attr_update_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.bgcolor = color; } + return fname; + } + fname = "ansi-sgr"; + if (strcmp(attr_name,fname) == 0) { + attr_update_sgr(fname, &tag->attr, value); + return fname; + } + fname = "ansi-color"; + if (strcmp(attr_name,fname) == 0) { + ic_color_t color = IC_COLOR_NONE;; + attr_update_ansi_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.color = color; } + return fname; + } + fname = "ansi-bgcolor"; + if (strcmp(attr_name,fname) == 0) { + ic_color_t color = IC_COLOR_NONE;; + attr_update_ansi_color(fname, &color, value); + if (color != IC_COLOR_NONE) { tag->attr.x.bgcolor = color; } + return fname; + } + fname = "width"; + if (strcmp(attr_name,fname) == 0) { + attr_update_width(&tag->width, ' ', value); + return fname; + } + fname = "max-width"; + if (strcmp(attr_name,fname) == 0) { + attr_update_width(&tag->width, 0, value); + return "width"; + } + else { + return NULL; + } +} + +static const style_t builtin_styles[] = { + { "b", { { IC_COLOR_NONE, IC_ON , IC_NONE, IC_COLOR_NONE, IC_NONE, IC_NONE } } }, + { "r", { { IC_COLOR_NONE, IC_NONE, IC_ON , IC_COLOR_NONE, IC_NONE, IC_NONE } } }, + { "u", { { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_ON , IC_NONE } } }, + { "i", { { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_NONE, IC_ON } } }, + { "em", { { IC_COLOR_NONE, IC_ON , IC_NONE, IC_COLOR_NONE, IC_NONE, IC_NONE } } }, // bold + { "url",{ { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_ON, IC_NONE } } }, // underline + { NULL, { { IC_COLOR_NONE, IC_NONE, IC_NONE, IC_COLOR_NONE, IC_NONE, IC_NONE } } } +}; + +static void attr_update_with_styles( tag_t* tag, const char* attr_name, const char* value, + bool usebgcolor, const style_t* styles, ssize_t count ) +{ + // direct hex color? + if (attr_name[0] == '#' && (value == NULL || value[0]==0)) { + value = attr_name; + attr_name = (usebgcolor ? "bgcolor" : "color"); + } + // first try if it is a builtin property + const char* name; + if ((name = attr_update_property(tag,attr_name,value)) != NULL) { + if (tag->name != NULL) tag->name = name; + return; + } + // then check all styles + while( count-- > 0 ) { + const style_t* style = styles + count; + if (strcmp(style->name,attr_name) == 0) { + tag->attr = attr_update_with(tag->attr,style->attr); + if (tag->name != NULL) tag->name = style->name; + return; + } + } + // check builtin styles; todo: binary search? + for( const style_t* style = builtin_styles; style->name != NULL; style++) { + if (strcmp(style->name,attr_name) == 0) { + tag->attr = attr_update_with(tag->attr,style->attr); + if (tag->name != NULL) tag->name = style->name; + return; + } + } + // check colors as a style + ssize_t lo = 0; + ssize_t hi = IC_HTML_COLOR_COUNT-1; + while( lo <= hi ) { + ssize_t mid = (lo + hi) / 2; + style_color_t* info = &html_colors[mid]; + int cmp = strcmp(info->name,attr_name); + if (cmp < 0) { + lo = mid+1; + } + else if (cmp > 0) { + hi = mid-1; + } + else { + attr_t cattr = attr_none(); + if (usebgcolor) { cattr.x.bgcolor = info->color; } + else { cattr.x.color = info->color; } + tag->attr = attr_update_with(tag->attr,cattr); + if (tag->name != NULL) tag->name = info->name; + return; + } + } + // not found + bbcode_invalid("bbcode: unknown style: %s\n", attr_name); +} + + +ic_private attr_t bbcode_style( bbcode_t* bb, const char* style_name ) { + tag_t tag; + tag_init(&tag); + attr_update_with_styles( &tag, style_name, NULL, false, bb->styles, bb->styles_count ); + return tag.attr; +} + +//------------------------------------------------------------- +// Parse tags +//------------------------------------------------------------- + +ic_private const char* parse_skip_white(const char* s) { + while( *s != 0 && *s != ']') { + if (!(*s == ' ' || *s == '\t' || *s == '\n' || *s == '\r')) break; + s++; + } + return s; +} + +ic_private const char* parse_skip_to_white(const char* s) { + while( *s != 0 && *s != ']') { + if (*s == ' ' || *s == '\t' || *s == '\n' || *s == '\r') break; + s++; + } + return parse_skip_white(s); +} + +ic_private const char* parse_skip_to_end(const char* s) { + while( *s != 0 && *s != ']' ) { s++; } + return s; +} + +ic_private const char* parse_attr_name(const char* s) { + if (*s == '#') { + s++; // hex rgb color as id + while( *s != 0 && *s != ']') { + if (!((*s >= 'a' && *s <= 'f') || (*s >= 'A' && *s <= 'Z') || (*s >= '0' && *s <= '9'))) break; + s++; + } + } + else { + while( *s != 0 && *s != ']') { + if (!((*s >= 'a' && *s <= 'z') || (*s >= 'A' && *s <= 'Z') || + (*s >= '0' && *s <= '9') || *s == '_' || *s == '-')) break; + s++; + } + } + return s; +} + +ic_private const char* parse_value(const char* s, const char** start, const char** end) { + if (*s == '"') { + s++; + *start = s; + while( *s != 0 ) { + if (*s == '"') break; + s++; + } + *end = s; + if (*s == '"') { s++; } + } + else if (*s == '#') { + *start = s; + s++; + while( *s != 0 ) { + if (!((*s >= 'a' && *s <= 'f') || (*s >= 'A' && *s <= 'Z') || (*s >= '0' && *s <= '9'))) break; + s++; + } + *end = s; + } + else { + *start = s; + while( *s != 0 ) { + if (!((*s >= 'a' && *s <= 'z') || (*s >= 'A' && *s <= 'F') || (*s >= '0' && *s <= '9') || *s == '-' || *s == '_')) break; + s++; + } + *end = s; + } + return s; +} + +ic_private const char* parse_tag_value( tag_t* tag, char* idbuf, const char* s, const style_t* styles, ssize_t scount ) { + // parse: \s*[\w-]+\s*(=\s*) + bool usebgcolor = false; + const char* id = s; + const char* idend = parse_attr_name(id); + const char* val = NULL; + const char* valend = NULL; + if (id == idend) { + bbcode_invalid("bbcode: empty identifier? %.10s...\n", id ); + return parse_skip_to_white(id); + } + // "on" bgcolor? + s = parse_skip_white(idend); + if (idend - id == 2 && ic_strnicmp(id,"on",2) == 0 && *s != '=') { + usebgcolor = true; + id = s; + idend = parse_attr_name(id); + if (id == idend) { + bbcode_invalid("bbcode: empty identifier follows 'on'? %.10s...\n", id ); + return parse_skip_to_white(id); + } + s = parse_skip_white(idend); + } + // value + if (*s == '=') { + s++; + s = parse_skip_white(s); + s = parse_value(s, &val, &valend); + s = parse_skip_white(s); + } + // limit name and attr to 128 bytes + char valbuf[128]; + ic_strncpy( idbuf, 128, id, idend - id); + ic_strncpy( valbuf, 128, val, valend - val); + ic_str_tolower(idbuf); + ic_str_tolower(valbuf); + attr_update_with_styles( tag, idbuf, valbuf, usebgcolor, styles, scount ); + return s; +} + +static const char* parse_tag_values( tag_t* tag, char* idbuf, const char* s, const style_t* styles, ssize_t scount ) { + s = parse_skip_white(s); + idbuf[0] = 0; + ssize_t count = 0; + while( *s != 0 && *s != ']') { + char idbuf_next[128]; + s = parse_tag_value(tag, (count==0 ? idbuf : idbuf_next), s, styles, scount); + count++; + } + if (*s == ']') { s++; } + return s; +} + +static const char* parse_tag( tag_t* tag, char* idbuf, bool* open, bool* pre, const char* s, const style_t* styles, ssize_t scount ) { + *open = true; + *pre = false; + if (*s != '[') return s; + s = parse_skip_white(s+1); + if (*s == '!') { // pre + *pre = true; + s = parse_skip_white(s+1); + } + else if (*s == '/') { + *open = false; + s = parse_skip_white(s+1); + }; + s = parse_tag_values( tag, idbuf, s, styles, scount); + return s; +} + + +//--------------------------------------------------------- +// Styles +//--------------------------------------------------------- + +static void bbcode_parse_tag_content( bbcode_t* bb, const char* s, tag_t* tag ) { + tag_init(tag); + if (s != NULL) { + char idbuf[128]; + parse_tag_values(tag, idbuf, s, bb->styles, bb->styles_count); + } +} + +ic_private void bbcode_style_def( bbcode_t* bb, const char* style_name, const char* s ) { + tag_t tag; + bbcode_parse_tag_content( bb, s, &tag); + bbcode_style_add(bb, style_name, tag.attr); +} + +ic_private void bbcode_style_open( bbcode_t* bb, const char* fmt ) { + tag_t tag; + bbcode_parse_tag_content(bb, fmt, &tag); + term_set_attr( bb->term, bbcode_open(bb, 0, &tag, term_get_attr(bb->term)) ); +} + +ic_private void bbcode_style_close( bbcode_t* bb, const char* fmt ) { + const ssize_t base = bb->tags_nesting - 1; // as we end a style + tag_t tag; + bbcode_parse_tag_content(bb, fmt, &tag); + tag_t prev; + if (bbcode_close(bb, base, tag.name, &prev)) { + term_set_attr( bb->term, prev.attr ); + } +} + +//--------------------------------------------------------- +// Restrict to width +//--------------------------------------------------------- + +static void bbcode_restrict_width( ssize_t start, width_t width, stringbuf_t* out, attrbuf_t* attr_out ) { + if (width.w <= 0) return; + assert(start <= sbuf_len(out)); + assert(attr_out == NULL || sbuf_len(out) == attrbuf_len(attr_out)); + const char* s = sbuf_string(out) + start; + const ssize_t len = sbuf_len(out) - start; + const ssize_t w = str_column_width(s); + if (w == width.w) return; // fits exactly + if (w > width.w) { + // too large + ssize_t innerw = (width.dots && width.w > 3 ? width.w-3 : width.w); + if (width.align == IC_ALIGN_RIGHT) { + // right align + const ssize_t ndel = str_skip_until_fit( s, innerw ); + sbuf_delete_at( out, start, ndel ); + attrbuf_delete_at( attr_out, start, ndel ); + if (innerw < width.w) { + // add dots + sbuf_insert_at( out, "...", start ); + attr_t attr = attrbuf_attr_at(attr_out, start); + attrbuf_insert_at( attr_out, start, 3, attr); + } + } + else { + // left or center align + ssize_t count = str_take_while_fit( s, innerw ); + sbuf_delete_at( out, start + count, len - count ); + attrbuf_delete_at( attr_out, start + count, len - count ); + if (innerw < width.w) { + // add dots + attr_t attr = attrbuf_attr_at(attr_out,start); + attrbuf_append_n( out, attr_out, "...", 3, attr ); + } + } + } + else { + // too short, pad to width + const ssize_t diff = (width.w - w); + const ssize_t pad_left = (width.align == IC_ALIGN_RIGHT ? diff : (width.align == IC_ALIGN_LEFT ? 0 : diff / 2)); + const ssize_t pad_right = (width.align == IC_ALIGN_LEFT ? diff : (width.align == IC_ALIGN_RIGHT ? 0 : diff - pad_left)); + if (width.fill != 0 && pad_left > 0) { + const attr_t attr = attrbuf_attr_at(attr_out,start); + for( ssize_t i = 0; i < pad_left; i++) { // todo: optimize + sbuf_insert_char_at(out, width.fill, start); + } + attrbuf_insert_at( attr_out, start, pad_left, attr ); + } + if (width.fill != 0 && pad_right > 0) { + const attr_t attr = attrbuf_attr_at(attr_out,sbuf_len(out) - 1); + char buf[2]; + buf[0] = width.fill; + buf[1] = 0; + for( ssize_t i = 0; i < pad_right; i++) { // todo: optimize + attrbuf_append_n( out, attr_out, buf, 1, attr ); + } + } + } +} + +//--------------------------------------------------------- +// Print +//--------------------------------------------------------- + +ic_private ssize_t bbcode_process_tag( bbcode_t* bb, const char* s, const ssize_t nesting_base, + stringbuf_t* out, attrbuf_t* attr_out, attr_t* cur_attr ) { + assert(*s == '['); + tag_t tag; + tag_init(&tag); + bool open = true; + bool ispre = false; + char idbuf[128]; + const char* end = parse_tag( &tag, idbuf, &open, &ispre, s, bb->styles, bb->styles_count ); // todo: styles + assert(end > s); + if (open) { + if (!ispre) { + // open tag + *cur_attr = bbcode_open( bb, sbuf_len(out), &tag, *cur_attr ); + } + else { + // scan pre to end tag + attr_t attr = attr_update_with(*cur_attr, tag.attr); + char pre[132]; + if (snprintf(pre, 132, "[/%s]", idbuf) < ssizeof(pre)) { + const char* etag = strstr(end,pre); + if (etag == NULL) { + const ssize_t len = ic_strlen(end); + attrbuf_append_n(out, attr_out, end, len, attr); + end += len; + } + else { + attrbuf_append_n(out, attr_out, end, (etag - end), attr); + end = etag + ic_strlen(pre); + } + } + } + } + else { + // pop the tag + tag_t prev; + if (bbcode_close( bb, nesting_base, tag.name, &prev)) { + *cur_attr = prev.attr; + if (prev.width.w > 0) { + // closed a width tag; restrict the output to width + bbcode_restrict_width( prev.pos, prev.width, out, attr_out); + } + } + } + return (end - s); +} + +ic_private void bbcode_append( bbcode_t* bb, const char* s, stringbuf_t* out, attrbuf_t* attr_out ) { + if (bb == NULL || s == NULL) return; + attr_t attr = attr_none(); + const ssize_t base = bb->tags_nesting; // base; will not be popped + ssize_t i = 0; + while( s[i] != 0 ) { + // handle no tags in bulk + ssize_t nobb = 0; + char c; + while( (c = s[i+nobb]) != 0) { + if (c == '[' || c == '\\') { break; } + if (c == '\x1B' && s[i+nobb+1] == '[') { + nobb++; // don't count 'ESC[' as a tag opener + } + nobb++; + } + if (nobb > 0) { attrbuf_append_n(out, attr_out, s+i, nobb, attr); } + i += nobb; + // tag + if (s[i] == '[') { + i += bbcode_process_tag(bb, s+i, base, out, attr_out, &attr); + } + else if (s[i] == '\\') { + if (s[i+1] == '\\' || s[i+1] == '[') { + attrbuf_append_n(out, attr_out, s+i+1, 1, attr); // escape '\[' and '\\' + i += 2; + } + else { + attrbuf_append_n(out, attr_out, s+i, 1, attr); // pass '\\' as is + i++; + } + } + } + // pop unclosed openings + assert(bb->tags_nesting >= base); + while( bb->tags_nesting > base ) { + bbcode_tag_pop(bb,NULL); + }; +} + +ic_private void bbcode_print( bbcode_t* bb, const char* s ) { + if (bb->out == NULL || bb->out_attrs == NULL || s == NULL) return; + assert(sbuf_len(bb->out) == 0 && attrbuf_len(bb->out_attrs) == 0); + bbcode_append( bb, s, bb->out, bb->out_attrs ); + term_write_formatted( bb->term, sbuf_string(bb->out), attrbuf_attrs(bb->out_attrs,sbuf_len(bb->out)) ); + attrbuf_clear(bb->out_attrs); + sbuf_clear(bb->out); +} + +ic_private void bbcode_println( bbcode_t* bb, const char* s ) { + bbcode_print(bb,s); + term_writeln(bb->term, ""); +} + +ic_private void bbcode_vprintf( bbcode_t* bb, const char* fmt, va_list args ) { + if (bb->vout == NULL || fmt == NULL) return; + assert(sbuf_len(bb->vout) == 0); + sbuf_append_vprintf(bb->vout,fmt,args); + bbcode_print(bb, sbuf_string(bb->vout)); + sbuf_clear(bb->vout); +} + +ic_private void bbcode_printf( bbcode_t* bb, const char* fmt, ... ) { + va_list args; + va_start(args,fmt); + bbcode_vprintf(bb,fmt,args); + va_end(args); +} + +ic_private ssize_t bbcode_column_width( bbcode_t* bb, const char* s ) { + if (s==NULL || s[0] == 0) return 0; + if (bb->vout == NULL) { return str_column_width(s); } + assert(sbuf_len(bb->vout) == 0); + bbcode_append( bb, s, bb->vout, NULL); + const ssize_t w = str_column_width(sbuf_string(bb->vout)); + sbuf_clear(bb->vout); + return w; +} diff --git a/extern/isocline/src/bbcode.h b/extern/isocline/src/bbcode.h new file mode 100644 index 000000000..be96bfe2d --- /dev/null +++ b/extern/isocline/src/bbcode.h @@ -0,0 +1,37 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_BBCODE_H +#define IC_BBCODE_H + +#include +#include "common.h" +#include "term.h" + +struct bbcode_s; +typedef struct bbcode_s bbcode_t; + +ic_private bbcode_t* bbcode_new( alloc_t* mem, term_t* term ); +ic_private void bbcode_free( bbcode_t* bb ); + +ic_private void bbcode_style_add( bbcode_t* bb, const char* style_name, attr_t attr ); +ic_private void bbcode_style_def( bbcode_t* bb, const char* style_name, const char* s ); +ic_private void bbcode_style_open( bbcode_t* bb, const char* fmt ); +ic_private void bbcode_style_close( bbcode_t* bb, const char* fmt ); +ic_private attr_t bbcode_style( bbcode_t* bb, const char* style_name ); + +ic_private void bbcode_print( bbcode_t* bb, const char* s ); +ic_private void bbcode_println( bbcode_t* bb, const char* s ); +ic_private void bbcode_printf( bbcode_t* bb, const char* fmt, ... ); +ic_private void bbcode_vprintf( bbcode_t* bb, const char* fmt, va_list args ); + +ic_private ssize_t bbcode_column_width( bbcode_t* bb, const char* s ); + +// allows `attr_out == NULL`. +ic_private void bbcode_append( bbcode_t* bb, const char* s, stringbuf_t* out, attrbuf_t* attr_out ); + +#endif // IC_BBCODE_H diff --git a/extern/isocline/src/bbcode_colors.c b/extern/isocline/src/bbcode_colors.c new file mode 100644 index 000000000..245cd3de3 --- /dev/null +++ b/extern/isocline/src/bbcode_colors.c @@ -0,0 +1,194 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +// This file is included from "bbcode.c" and contains html color names + +#include "common.h" + +typedef struct style_color_s { + const char* name; + ic_color_t color; +} style_color_t; + +#define IC_HTML_COLOR_COUNT (172) + +// ordered list of HTML color names (so we can use binary search) +static style_color_t html_colors[IC_HTML_COLOR_COUNT+1] = { + { "aliceblue", IC_RGB(0xf0f8ff) }, + { "ansi-aqua", IC_ANSI_AQUA }, + { "ansi-black", IC_ANSI_BLACK }, + { "ansi-blue", IC_ANSI_BLUE }, + { "ansi-cyan", IC_ANSI_CYAN }, + { "ansi-darkgray", IC_ANSI_DARKGRAY }, + { "ansi-darkgrey", IC_ANSI_DARKGRAY }, + { "ansi-default", IC_ANSI_DEFAULT }, + { "ansi-fuchsia", IC_ANSI_FUCHSIA }, + { "ansi-gray", IC_ANSI_GRAY }, + { "ansi-green", IC_ANSI_GREEN }, + { "ansi-grey", IC_ANSI_GRAY }, + { "ansi-lightgray", IC_ANSI_LIGHTGRAY }, + { "ansi-lightgrey", IC_ANSI_LIGHTGRAY }, + { "ansi-lime" , IC_ANSI_LIME }, + { "ansi-magenta", IC_ANSI_MAGENTA }, + { "ansi-maroon", IC_ANSI_MAROON }, + { "ansi-navy", IC_ANSI_NAVY }, + { "ansi-olive", IC_ANSI_OLIVE }, + { "ansi-purple", IC_ANSI_PURPLE }, + { "ansi-red", IC_ANSI_RED }, + { "ansi-silver", IC_ANSI_SILVER }, + { "ansi-teal", IC_ANSI_TEAL }, + { "ansi-white", IC_ANSI_WHITE }, + { "ansi-yellow", IC_ANSI_YELLOW }, + { "antiquewhite", IC_RGB(0xfaebd7) }, + { "aqua", IC_RGB(0x00ffff) }, + { "aquamarine", IC_RGB(0x7fffd4) }, + { "azure", IC_RGB(0xf0ffff) }, + { "beige", IC_RGB(0xf5f5dc) }, + { "bisque", IC_RGB(0xffe4c4) }, + { "black", IC_RGB(0x000000) }, + { "blanchedalmond", IC_RGB(0xffebcd) }, + { "blue", IC_RGB(0x0000ff) }, + { "blueviolet", IC_RGB(0x8a2be2) }, + { "brown", IC_RGB(0xa52a2a) }, + { "burlywood", IC_RGB(0xdeb887) }, + { "cadetblue", IC_RGB(0x5f9ea0) }, + { "chartreuse", IC_RGB(0x7fff00) }, + { "chocolate", IC_RGB(0xd2691e) }, + { "coral", IC_RGB(0xff7f50) }, + { "cornflowerblue", IC_RGB(0x6495ed) }, + { "cornsilk", IC_RGB(0xfff8dc) }, + { "crimson", IC_RGB(0xdc143c) }, + { "cyan", IC_RGB(0x00ffff) }, + { "darkblue", IC_RGB(0x00008b) }, + { "darkcyan", IC_RGB(0x008b8b) }, + { "darkgoldenrod", IC_RGB(0xb8860b) }, + { "darkgray", IC_RGB(0xa9a9a9) }, + { "darkgreen", IC_RGB(0x006400) }, + { "darkgrey", IC_RGB(0xa9a9a9) }, + { "darkkhaki", IC_RGB(0xbdb76b) }, + { "darkmagenta", IC_RGB(0x8b008b) }, + { "darkolivegreen", IC_RGB(0x556b2f) }, + { "darkorange", IC_RGB(0xff8c00) }, + { "darkorchid", IC_RGB(0x9932cc) }, + { "darkred", IC_RGB(0x8b0000) }, + { "darksalmon", IC_RGB(0xe9967a) }, + { "darkseagreen", IC_RGB(0x8fbc8f) }, + { "darkslateblue", IC_RGB(0x483d8b) }, + { "darkslategray", IC_RGB(0x2f4f4f) }, + { "darkslategrey", IC_RGB(0x2f4f4f) }, + { "darkturquoise", IC_RGB(0x00ced1) }, + { "darkviolet", IC_RGB(0x9400d3) }, + { "deeppink", IC_RGB(0xff1493) }, + { "deepskyblue", IC_RGB(0x00bfff) }, + { "dimgray", IC_RGB(0x696969) }, + { "dimgrey", IC_RGB(0x696969) }, + { "dodgerblue", IC_RGB(0x1e90ff) }, + { "firebrick", IC_RGB(0xb22222) }, + { "floralwhite", IC_RGB(0xfffaf0) }, + { "forestgreen", IC_RGB(0x228b22) }, + { "fuchsia", IC_RGB(0xff00ff) }, + { "gainsboro", IC_RGB(0xdcdcdc) }, + { "ghostwhite", IC_RGB(0xf8f8ff) }, + { "gold", IC_RGB(0xffd700) }, + { "goldenrod", IC_RGB(0xdaa520) }, + { "gray", IC_RGB(0x808080) }, + { "green", IC_RGB(0x008000) }, + { "greenyellow", IC_RGB(0xadff2f) }, + { "grey", IC_RGB(0x808080) }, + { "honeydew", IC_RGB(0xf0fff0) }, + { "hotpink", IC_RGB(0xff69b4) }, + { "indianred", IC_RGB(0xcd5c5c) }, + { "indigo", IC_RGB(0x4b0082) }, + { "ivory", IC_RGB(0xfffff0) }, + { "khaki", IC_RGB(0xf0e68c) }, + { "lavender", IC_RGB(0xe6e6fa) }, + { "lavenderblush", IC_RGB(0xfff0f5) }, + { "lawngreen", IC_RGB(0x7cfc00) }, + { "lemonchiffon", IC_RGB(0xfffacd) }, + { "lightblue", IC_RGB(0xadd8e6) }, + { "lightcoral", IC_RGB(0xf08080) }, + { "lightcyan", IC_RGB(0xe0ffff) }, + { "lightgoldenrodyellow", IC_RGB(0xfafad2) }, + { "lightgray", IC_RGB(0xd3d3d3) }, + { "lightgreen", IC_RGB(0x90ee90) }, + { "lightgrey", IC_RGB(0xd3d3d3) }, + { "lightpink", IC_RGB(0xffb6c1) }, + { "lightsalmon", IC_RGB(0xffa07a) }, + { "lightseagreen", IC_RGB(0x20b2aa) }, + { "lightskyblue", IC_RGB(0x87cefa) }, + { "lightslategray", IC_RGB(0x778899) }, + { "lightslategrey", IC_RGB(0x778899) }, + { "lightsteelblue", IC_RGB(0xb0c4de) }, + { "lightyellow", IC_RGB(0xffffe0) }, + { "lime", IC_RGB(0x00ff00) }, + { "limegreen", IC_RGB(0x32cd32) }, + { "linen", IC_RGB(0xfaf0e6) }, + { "magenta", IC_RGB(0xff00ff) }, + { "maroon", IC_RGB(0x800000) }, + { "mediumaquamarine", IC_RGB(0x66cdaa) }, + { "mediumblue", IC_RGB(0x0000cd) }, + { "mediumorchid", IC_RGB(0xba55d3) }, + { "mediumpurple", IC_RGB(0x9370db) }, + { "mediumseagreen", IC_RGB(0x3cb371) }, + { "mediumslateblue", IC_RGB(0x7b68ee) }, + { "mediumspringgreen", IC_RGB(0x00fa9a) }, + { "mediumturquoise", IC_RGB(0x48d1cc) }, + { "mediumvioletred", IC_RGB(0xc71585) }, + { "midnightblue", IC_RGB(0x191970) }, + { "mintcream", IC_RGB(0xf5fffa) }, + { "mistyrose", IC_RGB(0xffe4e1) }, + { "moccasin", IC_RGB(0xffe4b5) }, + { "navajowhite", IC_RGB(0xffdead) }, + { "navy", IC_RGB(0x000080) }, + { "oldlace", IC_RGB(0xfdf5e6) }, + { "olive", IC_RGB(0x808000) }, + { "olivedrab", IC_RGB(0x6b8e23) }, + { "orange", IC_RGB(0xffa500) }, + { "orangered", IC_RGB(0xff4500) }, + { "orchid", IC_RGB(0xda70d6) }, + { "palegoldenrod", IC_RGB(0xeee8aa) }, + { "palegreen", IC_RGB(0x98fb98) }, + { "paleturquoise", IC_RGB(0xafeeee) }, + { "palevioletred", IC_RGB(0xdb7093) }, + { "papayawhip", IC_RGB(0xffefd5) }, + { "peachpuff", IC_RGB(0xffdab9) }, + { "peru", IC_RGB(0xcd853f) }, + { "pink", IC_RGB(0xffc0cb) }, + { "plum", IC_RGB(0xdda0dd) }, + { "powderblue", IC_RGB(0xb0e0e6) }, + { "purple", IC_RGB(0x800080) }, + { "rebeccapurple", IC_RGB(0x663399) }, + { "red", IC_RGB(0xff0000) }, + { "rosybrown", IC_RGB(0xbc8f8f) }, + { "royalblue", IC_RGB(0x4169e1) }, + { "saddlebrown", IC_RGB(0x8b4513) }, + { "salmon", IC_RGB(0xfa8072) }, + { "sandybrown", IC_RGB(0xf4a460) }, + { "seagreen", IC_RGB(0x2e8b57) }, + { "seashell", IC_RGB(0xfff5ee) }, + { "sienna", IC_RGB(0xa0522d) }, + { "silver", IC_RGB(0xc0c0c0) }, + { "skyblue", IC_RGB(0x87ceeb) }, + { "slateblue", IC_RGB(0x6a5acd) }, + { "slategray", IC_RGB(0x708090) }, + { "slategrey", IC_RGB(0x708090) }, + { "snow", IC_RGB(0xfffafa) }, + { "springgreen", IC_RGB(0x00ff7f) }, + { "steelblue", IC_RGB(0x4682b4) }, + { "tan", IC_RGB(0xd2b48c) }, + { "teal", IC_RGB(0x008080) }, + { "thistle", IC_RGB(0xd8bfd8) }, + { "tomato", IC_RGB(0xff6347) }, + { "turquoise", IC_RGB(0x40e0d0) }, + { "violet", IC_RGB(0xee82ee) }, + { "wheat", IC_RGB(0xf5deb3) }, + { "white", IC_RGB(0xffffff) }, + { "whitesmoke", IC_RGB(0xf5f5f5) }, + { "yellow", IC_RGB(0xffff00) }, + { "yellowgreen", IC_RGB(0x9acd32) }, + {NULL, 0} +}; diff --git a/extern/isocline/src/common.c b/extern/isocline/src/common.c new file mode 100644 index 000000000..1d9fb566c --- /dev/null +++ b/extern/isocline/src/common.c @@ -0,0 +1,347 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +#include +#include +#include +#include +#include "common.h" + + +//------------------------------------------------------------- +// String wrappers for ssize_t +//------------------------------------------------------------- + +ic_private ssize_t ic_strlen( const char* s ) { + if (s==NULL) return 0; + return to_ssize_t(strlen(s)); +} + +ic_private void ic_memmove( void* dest, const void* src, ssize_t n ) { + assert(dest!=NULL && src != NULL); + if (n <= 0) return; + memmove(dest,src,to_size_t(n)); +} + + +ic_private void ic_memcpy( void* dest, const void* src, ssize_t n ) { + assert(dest!=NULL && src != NULL); + if (dest == NULL || src == NULL || n <= 0) return; + memcpy(dest,src,to_size_t(n)); +} + +ic_private void ic_memset(void* dest, uint8_t value, ssize_t n) { + assert(dest!=NULL); + if (dest == NULL || n <= 0) return; + memset(dest,(int8_t)value,to_size_t(n)); +} + +ic_private bool ic_memnmove( void* dest, ssize_t dest_size, const void* src, ssize_t n ) { + assert(dest!=NULL && src != NULL); + if (n <= 0) return true; + if (dest_size < n) { assert(false); return false; } + memmove(dest,src,to_size_t(n)); + return true; +} + +ic_private bool ic_strcpy( char* dest, ssize_t dest_size /* including 0 */, const char* src) { + assert(dest!=NULL && src != NULL); + if (dest == NULL || dest_size <= 0) return false; + ssize_t slen = ic_strlen(src); + if (slen >= dest_size) return false; + strcpy(dest,src); + assert(dest[slen] == 0); + return true; +} + + +ic_private bool ic_strncpy( char* dest, ssize_t dest_size /* including 0 */, const char* src, ssize_t n) { + assert(dest!=NULL && n < dest_size); + if (dest == NULL || dest_size <= 0) return false; + if (n >= dest_size) return false; + if (src==NULL || n <= 0) { + dest[0] = 0; + } + else { + strncpy(dest,src,to_size_t(n)); + dest[n] = 0; + } + return true; +} + +//------------------------------------------------------------- +// String matching +//------------------------------------------------------------- + +ic_public bool ic_starts_with( const char* s, const char* prefix ) { + if (s==prefix) return true; + if (prefix==NULL) return true; + if (s==NULL) return false; + + ssize_t i; + for( i = 0; s[i] != 0 && prefix[i] != 0; i++) { + if (s[i] != prefix[i]) return false; + } + return (prefix[i] == 0); +} + +ic_private char ic_tolower( char c ) { + return (c >= 'A' && c <= 'Z' ? c - 'A' + 'a' : c); +} + +ic_private void ic_str_tolower(char* s) { + while(*s != 0) { + *s = ic_tolower(*s); + s++; + } +} + +ic_public bool ic_istarts_with( const char* s, const char* prefix ) { + if (s==prefix) return true; + if (prefix==NULL) return true; + if (s==NULL) return false; + + ssize_t i; + for( i = 0; s[i] != 0 && prefix[i] != 0; i++) { + if (ic_tolower(s[i]) != ic_tolower(prefix[i])) return false; + } + return (prefix[i] == 0); +} + + +ic_private int ic_strnicmp(const char* s1, const char* s2, ssize_t n) { + if (s1 == NULL && s2 == NULL) return 0; + if (s1 == NULL) return -1; + if (s2 == NULL) return 1; + ssize_t i; + for (i = 0; s1[i] != 0 && i < n; i++) { // note: if s2[i] == 0 the loop will stop as c1 != c2 + char c1 = ic_tolower(s1[i]); + char c2 = ic_tolower(s2[i]); + if (c1 < c2) return -1; + if (c1 > c2) return 1; + } + return ((i >= n || s2[i] == 0) ? 0 : -1); +} + +ic_private int ic_stricmp(const char* s1, const char* s2) { + ssize_t len1 = ic_strlen(s1); + ssize_t len2 = ic_strlen(s2); + if (len1 < len2) return -1; + if (len1 > len2) return 1; + return (ic_strnicmp(s1, s2, (len1 >= len2 ? len1 : len2))); +} + + +static const char* ic_stristr(const char* s, const char* pat) { + if (s==NULL) return NULL; + if (pat==NULL || pat[0] == 0) return s; + ssize_t patlen = ic_strlen(pat); + for (ssize_t i = 0; s[i] != 0; i++) { + if (ic_strnicmp(s + i, pat, patlen) == 0) return (s+i); + } + return NULL; +} + +ic_private bool ic_contains(const char* big, const char* s) { + if (big == NULL) return false; + if (s == NULL) return true; + return (strstr(big,s) != NULL); +} + +ic_private bool ic_icontains(const char* big, const char* s) { + if (big == NULL) return false; + if (s == NULL) return true; + return (ic_stristr(big,s) != NULL); +} + + +//------------------------------------------------------------- +// Unicode +// QUTF-8: See +// Raw bytes are code points 0xEE000 - 0xEE0FF +//------------------------------------------------------------- +#define IC_UNICODE_RAW ((unicode_t)(0xEE000U)) + +ic_private unicode_t unicode_from_raw(uint8_t c) { + return (IC_UNICODE_RAW + c); +} + +ic_private bool unicode_is_raw(unicode_t u, uint8_t* c) { + if (u >= IC_UNICODE_RAW && u <= IC_UNICODE_RAW + 0xFF) { + *c = (uint8_t)(u - IC_UNICODE_RAW); + return true; + } + else { + return false; + } +} + +ic_private void unicode_to_qutf8(unicode_t u, uint8_t buf[5]) { + memset(buf, 0, 5); + if (u <= 0x7F) { + buf[0] = (uint8_t)u; + } + else if (u <= 0x07FF) { + buf[0] = (0xC0 | ((uint8_t)(u >> 6))); + buf[1] = (0x80 | (((uint8_t)u) & 0x3F)); + } + else if (u <= 0xFFFF) { + buf[0] = (0xE0 | ((uint8_t)(u >> 12))); + buf[1] = (0x80 | (((uint8_t)(u >> 6)) & 0x3F)); + buf[2] = (0x80 | (((uint8_t)u) & 0x3F)); + } + else if (u <= 0x10FFFF) { + if (unicode_is_raw(u, &buf[0])) { + buf[1] = 0; + } + else { + buf[0] = (0xF0 | ((uint8_t)(u >> 18))); + buf[1] = (0x80 | (((uint8_t)(u >> 12)) & 0x3F)); + buf[2] = (0x80 | (((uint8_t)(u >> 6)) & 0x3F)); + buf[3] = (0x80 | (((uint8_t)u) & 0x3F)); + } + } +} + +// is this a utf8 continuation byte? +ic_private bool utf8_is_cont(uint8_t c) { + return ((c & 0xC0) == 0x80); +} + +ic_private unicode_t unicode_from_qutf8(const uint8_t* s, ssize_t len, ssize_t* count) { + unicode_t c0 = 0; + if (len <= 0 || s == NULL) { + goto fail; + } + // 1 byte + c0 = s[0]; + if (c0 <= 0x7F && len >= 1) { + if (count != NULL) *count = 1; + return c0; + } + else if (c0 <= 0xC1) { // invalid continuation byte or invalid 0xC0, 0xC1 + goto fail; + } + // 2 bytes + else if (c0 <= 0xDF && len >= 2 && utf8_is_cont(s[1])) { + if (count != NULL) *count = 2; + return (((c0 & 0x1F) << 6) | (s[1] & 0x3F)); + } + // 3 bytes: reject overlong and surrogate halves + else if (len >= 3 && + ((c0 == 0xE0 && s[1] >= 0xA0 && s[1] <= 0xBF && utf8_is_cont(s[2])) || + (c0 >= 0xE1 && c0 <= 0xEC && utf8_is_cont(s[1]) && utf8_is_cont(s[2])) + )) + { + if (count != NULL) *count = 3; + return (((c0 & 0x0F) << 12) | ((unicode_t)(s[1] & 0x3F) << 6) | (s[2] & 0x3F)); + } + // 4 bytes: reject overlong + else if (len >= 4 && + (((c0 == 0xF0 && s[1] >= 0x90 && s[1] <= 0xBF && utf8_is_cont(s[2]) && utf8_is_cont(s[3])) || + (c0 >= 0xF1 && c0 <= 0xF3 && utf8_is_cont(s[1]) && utf8_is_cont(s[2]) && utf8_is_cont(s[3])) || + (c0 == 0xF4 && s[1] >= 0x80 && s[1] <= 0x8F && utf8_is_cont(s[2]) && utf8_is_cont(s[3]))) + )) + { + if (count != NULL) *count = 4; + return (((c0 & 0x07) << 18) | ((unicode_t)(s[1] & 0x3F) << 12) | ((unicode_t)(s[2] & 0x3F) << 6) | (s[3] & 0x3F)); + } +fail: + if (count != NULL) *count = 1; + return unicode_from_raw(s[0]); +} + + +//------------------------------------------------------------- +// Debug +//------------------------------------------------------------- + +#if defined(IC_NO_DEBUG_MSG) +// nothing +#elif !defined(IC_DEBUG_TO_FILE) +ic_private void debug_msg(const char* fmt, ...) { + if (getenv("ISOCLINE_DEBUG")) { + va_list args; + va_start(args, fmt); + vfprintf(stderr, fmt, args); + va_end(args); + } +} +#else +ic_private void debug_msg(const char* fmt, ...) { + static int debug_init; + static const char* debug_fname = "isocline.debug.txt"; + // initialize? + if (debug_init==0) { + debug_init = -1; + const char* rdebug = getenv("ISOCLINE_DEBUG"); + if (rdebug!=NULL && strcmp(rdebug,"1") == 0) { + FILE* fdbg = fopen(debug_fname, "w"); + if (fdbg!=NULL) { + debug_init = 1; + fclose(fdbg); + } + } + } + if (debug_init <= 0) return; + + // write debug messages + FILE* fdbg = fopen(debug_fname, "a"); + if (fdbg==NULL) return; + va_list args; + va_start(args, fmt); + vfprintf(fdbg, fmt, args); + fclose(fdbg); + va_end(args); +} +#endif + + +//------------------------------------------------------------- +// Allocation +//------------------------------------------------------------- + +ic_private void* mem_malloc(alloc_t* mem, ssize_t sz) { + return mem->malloc(to_size_t(sz)); +} + +ic_private void* mem_zalloc(alloc_t* mem, ssize_t sz) { + void* p = mem_malloc(mem, sz); + if (p != NULL) memset(p, 0, to_size_t(sz)); + return p; +} + +ic_private void* mem_realloc(alloc_t* mem, void* p, ssize_t newsz) { + return mem->realloc(p, to_size_t(newsz)); +} + +ic_private void mem_free(alloc_t* mem, const void* p) { + mem->free((void*)p); +} + +ic_private char* mem_strdup(alloc_t* mem, const char* s) { + if (s==NULL) return NULL; + ssize_t n = ic_strlen(s); + char* p = mem_malloc_tp_n(mem, char, n+1); + if (p == NULL) return NULL; + ic_memcpy(p, s, n+1); + return p; +} + +ic_private char* mem_strndup(alloc_t* mem, const char* s, ssize_t n) { + if (s==NULL || n < 0) return NULL; + char* p = mem_malloc_tp_n(mem, char, n+1); + if (p == NULL) return NULL; + ssize_t i; + for (i = 0; i < n && s[i] != 0; i++) { + p[i] = s[i]; + } + assert(i <= n); + p[i] = 0; + return p; +} + diff --git a/extern/isocline/src/common.h b/extern/isocline/src/common.h new file mode 100644 index 000000000..dd5b25697 --- /dev/null +++ b/extern/isocline/src/common.h @@ -0,0 +1,187 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +#pragma once +#ifndef IC_COMMON_H +#define IC_COMMON_H + +//------------------------------------------------------------- +// Headers and defines +//------------------------------------------------------------- + +#include // ssize_t +#include +#include +#include +#include +#include +#include "../include/isocline.h" // ic_malloc_fun_t, ic_color_t etc. + +# ifdef __cplusplus +# define ic_extern_c extern "C" +# else +# define ic_extern_c +# endif + +#if defined(IC_SEPARATE_OBJS) +# define ic_public ic_extern_c +# if defined(__GNUC__) // includes clang and icc +# define ic_private __attribute__((visibility("hidden"))) +# else +# define ic_private +# endif +#else +# define ic_private static +# define ic_public ic_extern_c +#endif + +#define ic_unused(x) (void)(x) + + +//------------------------------------------------------------- +// ssize_t +//------------------------------------------------------------- + +#if defined(_MSC_VER) +typedef intptr_t ssize_t; +#endif + +#define ssizeof(tp) (ssize_t)(sizeof(tp)) +static inline size_t to_size_t(ssize_t sz) { return (sz >= 0 ? (size_t)sz : 0); } +static inline ssize_t to_ssize_t(size_t sz) { return (sz <= SIZE_MAX/2 ? (ssize_t)sz : 0); } + +ic_private void ic_memmove(void* dest, const void* src, ssize_t n); +ic_private void ic_memcpy(void* dest, const void* src, ssize_t n); +ic_private void ic_memset(void* dest, uint8_t value, ssize_t n); +ic_private bool ic_memnmove(void* dest, ssize_t dest_size, const void* src, ssize_t n); + +ic_private ssize_t ic_strlen(const char* s); +ic_private bool ic_strcpy(char* dest, ssize_t dest_size /* including 0 */, const char* src); +ic_private bool ic_strncpy(char* dest, ssize_t dest_size /* including 0 */, const char* src, ssize_t n); + +ic_private bool ic_contains(const char* big, const char* s); +ic_private bool ic_icontains(const char* big, const char* s); +ic_private char ic_tolower(char c); +ic_private void ic_str_tolower(char* s); +ic_private int ic_stricmp(const char* s1, const char* s2); +ic_private int ic_strnicmp(const char* s1, const char* s2, ssize_t n); + + + +//--------------------------------------------------------------------- +// Unicode +// +// We use "qutf-8" (quite like utf-8) encoding and decoding. +// Internally we always use valid utf-8. If we encounter invalid +// utf-8 bytes (or bytes >= 0x80 from any other encoding) we encode +// these as special code points in the "raw plane" (0xEE000 - 0xEE0FF). +// When decoding we are then able to restore such raw bytes as-is. +// See +//--------------------------------------------------------------------- + +typedef uint32_t unicode_t; + +ic_private void unicode_to_qutf8(unicode_t u, uint8_t buf[5]); +ic_private unicode_t unicode_from_qutf8(const uint8_t* s, ssize_t len, ssize_t* nread); // validating + +ic_private unicode_t unicode_from_raw(uint8_t c); +ic_private bool unicode_is_raw(unicode_t u, uint8_t* c); + +ic_private bool utf8_is_cont(uint8_t c); + + +//------------------------------------------------------------- +// Colors +//------------------------------------------------------------- + +// A color is either RGB or an ANSI code. +// (RGB colors have bit 24 set to distinguish them from the ANSI color palette colors.) +// (Isocline will automatically convert from RGB on terminals that do not support full colors) +typedef uint32_t ic_color_t; + +// Create a color from a 24-bit color value. +ic_private ic_color_t ic_rgb(uint32_t hex); + +// Create a color from a 8-bit red/green/blue components. +// The value of each component is capped between 0 and 255. +ic_private ic_color_t ic_rgbx(ssize_t r, ssize_t g, ssize_t b); + +#define IC_COLOR_NONE (0) +#define IC_RGB(rgb) (0x1000000 | (uint32_t)(rgb)) // ic_rgb(rgb) // define to it can be used as a constant + +// ANSI colors. +// The actual colors used is usually determined by the terminal theme +// See +#define IC_ANSI_BLACK (30) +#define IC_ANSI_MAROON (31) +#define IC_ANSI_GREEN (32) +#define IC_ANSI_OLIVE (33) +#define IC_ANSI_NAVY (34) +#define IC_ANSI_PURPLE (35) +#define IC_ANSI_TEAL (36) +#define IC_ANSI_SILVER (37) +#define IC_ANSI_DEFAULT (39) + +#define IC_ANSI_GRAY (90) +#define IC_ANSI_RED (91) +#define IC_ANSI_LIME (92) +#define IC_ANSI_YELLOW (93) +#define IC_ANSI_BLUE (94) +#define IC_ANSI_FUCHSIA (95) +#define IC_ANSI_AQUA (96) +#define IC_ANSI_WHITE (97) + +#define IC_ANSI_DARKGRAY IC_ANSI_GRAY +#define IC_ANSI_LIGHTGRAY IC_ANSI_SILVER +#define IC_ANSI_MAGENTA IC_ANSI_FUCHSIA +#define IC_ANSI_CYAN IC_ANSI_AQUA + + + +//------------------------------------------------------------- +// Debug +//------------------------------------------------------------- + +#if defined(IC_NO_DEBUG_MSG) +#define debug_msg(fmt,...) (void)(0) +#else +ic_private void debug_msg( const char* fmt, ... ); +#endif + + +//------------------------------------------------------------- +// Abstract environment +//------------------------------------------------------------- +struct ic_env_s; +typedef struct ic_env_s ic_env_t; + + +//------------------------------------------------------------- +// Allocation +//------------------------------------------------------------- + +typedef struct alloc_s { + ic_malloc_fun_t* malloc; + ic_realloc_fun_t* realloc; + ic_free_fun_t* free; +} alloc_t; + + +ic_private void* mem_malloc( alloc_t* mem, ssize_t sz ); +ic_private void* mem_zalloc( alloc_t* mem, ssize_t sz ); +ic_private void* mem_realloc( alloc_t* mem, void* p, ssize_t newsz ); +ic_private void mem_free( alloc_t* mem, const void* p ); +ic_private char* mem_strdup( alloc_t* mem, const char* s); +ic_private char* mem_strndup( alloc_t* mem, const char* s, ssize_t n); + +#define mem_zalloc_tp(mem,tp) (tp*)mem_zalloc(mem,ssizeof(tp)) +#define mem_malloc_tp_n(mem,tp,n) (tp*)mem_malloc(mem,(n)*ssizeof(tp)) +#define mem_zalloc_tp_n(mem,tp,n) (tp*)mem_zalloc(mem,(n)*ssizeof(tp)) +#define mem_realloc_tp(mem,tp,p,n) (tp*)mem_realloc(mem,p,(n)*ssizeof(tp)) + + +#endif // IC_COMMON_H diff --git a/extern/isocline/src/completers.c b/extern/isocline/src/completers.c new file mode 100644 index 000000000..e9701c166 --- /dev/null +++ b/extern/isocline/src/completers.c @@ -0,0 +1,675 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" +#include "stringbuf.h" +#include "completions.h" + + + +//------------------------------------------------------------- +// Word completion +//------------------------------------------------------------- + +// free variables for word completion +typedef struct word_closure_s { + long delete_before_adjust; + void* prev_env; + ic_completion_fun_t* prev_complete; +} word_closure_t; + + +// word completion callback +static bool token_add_completion_ex(ic_env_t* env, void* closure, const char* replacement, const char* display, const char* help, long delete_before, long delete_after) { + word_closure_t* wenv = (word_closure_t*)(closure); + // call the previous completer with an adjusted delete-before + return (*wenv->prev_complete)(env, wenv->prev_env, replacement, display, help, wenv->delete_before_adjust + delete_before, delete_after); +} + + +ic_public void ic_complete_word(ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, + ic_is_char_class_fun_t* is_word_char) +{ + if (is_word_char == NULL) is_word_char = &ic_char_is_nonseparator; + + ssize_t len = ic_strlen(prefix); + ssize_t pos = len; // will be start of the 'word' (excluding a potential start quote) + while (pos > 0) { + // go back one code point + ssize_t ofs = str_prev_ofs(prefix, pos, NULL); + if (ofs <= 0) break; + if (!(*is_word_char)(prefix + (pos - ofs), (long)ofs)) { + break; + } + pos -= ofs; + } + if (pos < 0) { pos = 0; } + + // stop if empty word + // if (len == pos) return; + + // set up the closure + word_closure_t wenv; + wenv.delete_before_adjust = (long)(len - pos); + wenv.prev_complete = cenv->complete; + wenv.prev_env = cenv->env; + cenv->complete = &token_add_completion_ex; + cenv->closure = &wenv; + + // and call the user completion routine + (*fun)(cenv, prefix + pos); + + // restore the original environment + cenv->complete = wenv.prev_complete; + cenv->closure = wenv.prev_env; +} + + +//------------------------------------------------------------- +// Quoted word completion (with escape characters) +//------------------------------------------------------------- + +// free variables for word completion +typedef struct qword_closure_s { + char escape_char; + char quote; + long delete_before_adjust; + stringbuf_t* sbuf; + void* prev_env; + ic_is_char_class_fun_t* is_word_char; + ic_completion_fun_t* prev_complete; +} qword_closure_t; + + +// word completion callback +static bool qword_add_completion_ex(ic_env_t* env, void* closure, const char* replacement, const char* display, const char* help, + long delete_before, long delete_after) { + qword_closure_t* wenv = (qword_closure_t*)(closure); + sbuf_replace( wenv->sbuf, replacement ); + if (wenv->quote != 0) { + // add end quote + sbuf_append_char( wenv->sbuf, wenv->quote); + } + else { + // escape non-word characters if it was not quoted + ssize_t pos = 0; + ssize_t next; + while ( (next = sbuf_next_ofs(wenv->sbuf, pos, NULL)) > 0 ) + { + if (!(*wenv->is_word_char)(sbuf_string(wenv->sbuf) + pos, (long)next)) { // strchr(wenv->non_word_char, sbuf_char_at( wenv->sbuf, pos )) != NULL) { + sbuf_insert_char_at( wenv->sbuf, wenv->escape_char, pos); + pos++; + } + pos += next; + } + } + // and call the previous completion function + return (*wenv->prev_complete)( env, wenv->prev_env, sbuf_string(wenv->sbuf), display, help, wenv->delete_before_adjust + delete_before, delete_after ); +} + + +ic_public void ic_complete_qword( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char ) { + ic_complete_qword_ex( cenv, prefix, fun, is_word_char, '\\', NULL); +} + + +ic_public void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, + ic_is_char_class_fun_t* is_word_char, char escape_char, const char* quote_chars ) { + if (is_word_char == NULL) is_word_char = &ic_char_is_nonseparator ; + if (quote_chars == NULL) quote_chars = "'\""; + + ssize_t len = ic_strlen(prefix); + ssize_t pos; // will be start of the 'word' (excluding a potential start quote) + char quote = 0; + ssize_t quote_len = 0; + + // 1. look for a starting quote + if (quote_chars[0] != 0) { + // we go forward and count all quotes; if it is uneven, we need to complete quoted. + ssize_t qpos_open = -1; + ssize_t qpos_close = -1; + ssize_t qcount = 0; + pos = 0; + while(pos < len) { + if (prefix[pos] == escape_char && prefix[pos+1] != 0 && + !(*is_word_char)(prefix + pos + 1, 1)) // strchr(non_word_char, prefix[pos+1]) != NULL + { + pos++; // skip escape and next char + } + else if (qcount % 2 == 0 && strchr(quote_chars, prefix[pos]) != NULL) { + // open quote + qpos_open = pos; + quote = prefix[pos]; + qcount++; + } + else if (qcount % 2 == 1 && prefix[pos] == quote) { + // close quote + qpos_close = pos; + qcount++; + } + else if (!(*is_word_char)(prefix + pos, 1)) { // strchr(non_word_char, prefix[pos]) != NULL) { + qpos_close = -1; + } + ssize_t ofs = str_next_ofs( prefix, len, pos, NULL ); + if (ofs <= 0) break; + pos += ofs; + } + if ((qcount % 2 == 0 && qpos_close >= 0) || // if the last quote is only followed by word chars, we still complete it + (qcount % 2 == 1)) // opening quote found + { + quote_len = (len - qpos_open - 1); + pos = qpos_open + 1; // pos points to the word start just after the quote. + } + else { + quote = 0; + } + } + + // 2. if we did not find a quoted word, look for non-word-chars + if (quote == 0) { + pos = len; + while(pos > 0) { + // go back one code point + ssize_t ofs = str_prev_ofs(prefix, pos, NULL ); + if (ofs <= 0) break; + if (!(*is_word_char)(prefix + (pos - ofs), (long)ofs)) { // strchr(non_word_char, prefix[pos - ofs]) != NULL) { + // non word char, break if it is not escaped + if (pos <= ofs || prefix[pos - ofs - 1] != escape_char) break; + // otherwise go on + pos--; // skip escaped char + } + pos -= ofs; + } + } + + // stop if empty word + // if (len == pos) return; + + // allocate new unescaped word prefix + char* word = mem_strndup( cenv->env->mem, prefix + pos, (quote==0 ? len - pos : quote_len)); + if (word == NULL) return; + + if (quote == 0) { + // unescape prefix + ssize_t wlen = len - pos; + ssize_t wpos = 0; + while (wpos < wlen) { + ssize_t ofs = str_next_ofs(word, wlen, wpos, NULL); + if (ofs <= 0) break; + if (word[wpos] == escape_char && word[wpos+1] != 0 && + !(*is_word_char)(word + wpos + 1, (long)ofs)) // strchr(non_word_char, word[wpos+1]) != NULL) { + { + ic_memmove(word + wpos, word + wpos + 1, wlen - wpos /* including 0 */); + } + wpos += ofs; + } + } + #ifdef _WIN32 + else { + // remove inner quote: "c:\Program Files\"Win + ssize_t wlen = len - pos; + ssize_t wpos = 0; + while (wpos < wlen) { + ssize_t ofs = str_next_ofs(word, wlen, wpos, NULL); + if (ofs <= 0) break; + if (word[wpos] == escape_char && word[wpos+1] == quote) { + word[wpos+1] = escape_char; + ic_memmove(word + wpos, word + wpos + 1, wlen - wpos /* including 0 */); + } + wpos += ofs; + } + } + #endif + + // set up the closure + qword_closure_t wenv; + wenv.quote = quote; + wenv.is_word_char = is_word_char; + wenv.escape_char = escape_char; + wenv.delete_before_adjust = (long)(len - pos); + wenv.prev_complete = cenv->complete; + wenv.prev_env = cenv->env; + wenv.sbuf = sbuf_new(cenv->env->mem); + if (wenv.sbuf == NULL) { mem_free(cenv->env->mem, word); return; } + cenv->complete = &qword_add_completion_ex; + cenv->closure = &wenv; + + // and call the user completion routine + (*fun)( cenv, word ); + + // restore the original environment + cenv->complete = wenv.prev_complete; + cenv->closure = wenv.prev_env; + + sbuf_free(wenv.sbuf); + mem_free(cenv->env->mem, word); +} + + + + +//------------------------------------------------------------- +// Complete file names +// Listing files +//------------------------------------------------------------- +#include + +typedef enum file_type_e { + // must follow BSD style LSCOLORS order + FT_DEFAULT = 0, + FT_DIR, + FT_SYM, + FT_SOCK, + FT_PIPE, + FT_BLOCK, + FT_CHAR, + FT_SETUID, + FT_SETGID, + FT_DIR_OW_STICKY, + FT_DIR_OW, + FT_DIR_STICKY, + FT_EXE, + FT_LAST +} file_type_t; + +static int cli_color; // 1 enabled, 0 not initialized, -1 disabled +static const char* lscolors = "exfxcxdxbxegedabagacad"; // default BSD setting +static const char* ls_colors; +static const char* ls_colors_names[] = { "no=","di=","ln=","so=","pi=","bd=","cd=","su=","sg=","tw=","ow=","st=","ex=", NULL }; + +static bool ls_colors_init(void) { + if (cli_color != 0) return (cli_color >= 1); + // colors enabled? + const char* s = getenv("CLICOLOR"); + if (s==NULL || (strcmp(s, "1")!=0 && strcmp(s, "") != 0)) { + cli_color = -1; + return false; + } + cli_color = 1; + s = getenv("LS_COLORS"); + if (s != NULL) { ls_colors = s; } + s = getenv("LSCOLORS"); + if (s != NULL) { lscolors = s; } + return true; +} + +static bool ls_valid_esc(ssize_t c) { + return ((c==0 || c==1 || c==4 || c==7 || c==22 || c==24 || c==27) || + (c >= 30 && c <= 37) || (c >= 40 && c <= 47) || + (c >= 90 && c <= 97) || (c >= 100 && c <= 107)); +} + +static bool ls_colors_from_key(stringbuf_t* sb, const char* key) { + // find key + ssize_t keylen = ic_strlen(key); + if (keylen <= 0) return false; + const char* p = strstr(ls_colors, key); + if (p == NULL) return false; + p += keylen; + if (key[keylen-1] != '=') { + if (*p != '=') return false; + p++; + } + ssize_t len = 0; + while (p[len] != 0 && p[len] != ':') { + len++; + } + if (len <= 0) return false; + sbuf_append(sb, "[ansi-sgr=\"" ); + sbuf_append_n(sb, p, len ); + sbuf_append(sb, "\"]"); + return true; +} + +static int ls_colors_from_char(char c) { + if (c >= 'a' && c <= 'h') { return (c - 'a'); } + else if (c >= 'A' && c <= 'H') { return (c - 'A') + 8; } + else if (c == 'x') { return 256; } + else return 256; // default +} + +static bool ls_colors_append(stringbuf_t* sb, file_type_t ft, const char* ext) { + if (!ls_colors_init()) return false; + if (ls_colors != NULL) { + // GNU style + if (ft == FT_DEFAULT && ext != NULL) { + // first try extension match + if (ls_colors_from_key(sb, ext)) return true; + } + if (ft >= FT_DEFAULT && ft < FT_LAST) { + // then a filetype match + const char* key = ls_colors_names[ft]; + if (ls_colors_from_key(sb, key)) return true; + } + } + else if (lscolors != NULL) { + // BSD style + char fg = 'x'; + char bg = 'x'; + if (ic_strlen(lscolors) > (2*(ssize_t)ft)+1) { + fg = lscolors[2*ft]; + bg = lscolors[2*ft + 1]; + } + sbuf_appendf(sb, "[ansi-color=%d ansi-bgcolor=%d]", ls_colors_from_char(fg), ls_colors_from_char(bg) ); + return true; + } + return false; +} + +static void ls_colorize(bool no_lscolor, stringbuf_t* sb, file_type_t ft, const char* name, const char* ext, char dirsep) { + bool close = (no_lscolor ? false : ls_colors_append( sb, ft, ext)); + sbuf_append(sb, "[!pre]" ); + sbuf_append(sb, name); + if (dirsep != 0) sbuf_append_char(sb, dirsep); + sbuf_append(sb,"[/pre]" ); + if (close) { sbuf_append(sb, "[/]"); } +} + +#if defined(_WIN32) +#include +#include + +static bool os_is_dir(const char* cpath) { + struct _stat64 st = { 0 }; + _stat64(cpath, &st); + return ((st.st_mode & _S_IFDIR) != 0); +} + +static file_type_t os_get_filetype(const char* cpath) { + struct _stat64 st = { 0 }; + _stat64(cpath, &st); + if (((st.st_mode) & _S_IFDIR) != 0) return FT_DIR; + if (((st.st_mode) & _S_IFCHR) != 0) return FT_CHAR; + if (((st.st_mode) & _S_IFIFO) != 0) return FT_PIPE; + if (((st.st_mode) & _S_IEXEC) != 0) return FT_EXE; + return FT_DEFAULT; +} + + +#define dir_cursor intptr_t +#define dir_entry struct __finddata64_t + +static bool os_findfirst(alloc_t* mem, const char* path, dir_cursor* d, dir_entry* entry) { + stringbuf_t* spath = sbuf_new(mem); + if (spath == NULL) return false; + sbuf_append(spath, path); + sbuf_append(spath, "\\*"); + *d = _findfirsti64(sbuf_string(spath), entry); + mem_free(mem,spath); + return (*d != -1); +} + +static bool os_findnext(dir_cursor d, dir_entry* entry) { + return (_findnexti64(d, entry) == 0); +} + +static void os_findclose(dir_cursor d) { + _findclose(d); +} + +static const char* os_direntry_name(dir_entry* entry) { + return entry->name; +} + +static bool os_path_is_absolute( const char* path ) { + if (path != NULL && path[0] != 0 && path[1] == ':' && (path[2] == '\\' || path[2] == '/' || path[2] == 0)) { + char drive = path[0]; + return ((drive >= 'A' && drive <= 'Z') || (drive >= 'a' && drive <= 'z')); + } + else return false; +} + +ic_private char ic_dirsep(void) { + return '\\'; +} +#else + +#include +#include +#include +#include + +static bool os_is_dir(const char* cpath) { + struct stat st; + memset(&st, 0, sizeof(st)); + stat(cpath, &st); + return (S_ISDIR(st.st_mode)); +} + +static file_type_t os_get_filetype(const char* cpath) { + struct stat st; + memset(&st, 0, sizeof(st)); + lstat(cpath, &st); + switch ((st.st_mode)&S_IFMT) { + case S_IFSOCK: return FT_SOCK; + case S_IFLNK: { + return FT_SYM; + } + case S_IFIFO: return FT_PIPE; + case S_IFCHR: return FT_CHAR; + case S_IFBLK: return FT_BLOCK; + case S_IFDIR: { + if ((st.st_mode & S_ISUID) != 0) return FT_SETUID; + if ((st.st_mode & S_ISGID) != 0) return FT_SETGID; + if ((st.st_mode & S_IWGRP) != 0 && (st.st_mode & S_ISVTX) != 0) return FT_DIR_OW_STICKY; + if ((st.st_mode & S_IWGRP)) return FT_DIR_OW; + if ((st.st_mode & S_ISVTX)) return FT_DIR_STICKY; + return FT_DIR; + } + case S_IFREG: + default: { + if ((st.st_mode & S_IXUSR) != 0) return FT_EXE; + return FT_DEFAULT; + } + } +} + + +#define dir_cursor DIR* +#define dir_entry struct dirent* + +static bool os_findnext(dir_cursor d, dir_entry* entry) { + *entry = readdir(d); + return (*entry != NULL); +} + +static bool os_findfirst(alloc_t* mem, const char* cpath, dir_cursor* d, dir_entry* entry) { + ic_unused(mem); + *d = opendir(cpath); + if (*d == NULL) { + return false; + } + else { + return os_findnext(*d, entry); + } +} + +static void os_findclose(dir_cursor d) { + closedir(d); +} + +static const char* os_direntry_name(dir_entry* entry) { + return (*entry)->d_name; +} + +static bool os_path_is_absolute( const char* path ) { + return (path != NULL && path[0] == '/'); +} + +ic_private char ic_dirsep(void) { + return '/'; +} +#endif + + + +//------------------------------------------------------------- +// File completion +//------------------------------------------------------------- + +static bool ends_with_n(const char* name, ssize_t name_len, const char* ending, ssize_t len) { + if (name_len < len) return false; + if (ending == NULL || len <= 0) return true; + for (ssize_t i = 1; i <= len; i++) { + char c1 = name[name_len - i]; + char c2 = ending[len - i]; + #ifdef _WIN32 + if (ic_tolower(c1) != ic_tolower(c2)) return false; + #else + if (c1 != c2) return false; + #endif + } + return true; +} + +static bool match_extension(const char* name, const char* extensions) { + if (extensions == NULL || extensions[0] == 0) return true; + if (name == NULL) return false; + ssize_t name_len = ic_strlen(name); + ssize_t len = ic_strlen(extensions); + ssize_t cur = 0; + //debug_msg("match extensions: %s ~ %s", name, extensions); + for (ssize_t end = 0; end <= len; end++) { + if (extensions[end] == ';' || extensions[end] == 0) { + if (ends_with_n(name, name_len, extensions+cur, (end - cur))) { + return true; + } + cur = end+1; + } + } + return false; +} + +static bool filename_complete_indir( ic_completion_env_t* cenv, stringbuf_t* dir, + stringbuf_t* dir_prefix, stringbuf_t* display, + const char* base_prefix, + char dir_sep, const char* extensions ) +{ + dir_cursor d = 0; + dir_entry entry; + bool cont = true; + if (os_findfirst(cenv->env->mem, sbuf_string(dir), &d, &entry)) { + do { + const char* name = os_direntry_name(&entry); + if (name != NULL && strcmp(name, ".") != 0 && strcmp(name, "..") != 0 && + ic_istarts_with(name, base_prefix)) + { + // possible match, first check if it is a directory + file_type_t ft; + bool isdir; + const ssize_t plen = sbuf_len(dir_prefix); + sbuf_append(dir_prefix, name); + { // check directory and potentially add a dirsep to the dir_prefix + const ssize_t dlen = sbuf_len(dir); + sbuf_append_char(dir,ic_dirsep()); + sbuf_append(dir,name); + ft = os_get_filetype(sbuf_string(dir)); + isdir = os_is_dir(sbuf_string(dir)); + if (isdir && dir_sep != 0) { + sbuf_append_char(dir_prefix,dir_sep); + } + sbuf_delete_from(dir,dlen); // restore dir + } + if (isdir || match_extension(name, extensions)) { + // add completion + sbuf_clear(display); + ls_colorize(cenv->env->no_lscolors, display, ft, name, NULL, (isdir ? dir_sep : 0)); + cont = ic_add_completion_ex(cenv, sbuf_string(dir_prefix), sbuf_string(display), NULL); + } + sbuf_delete_from( dir_prefix, plen ); // restore dir_prefix + } + } while (cont && os_findnext(d, &entry)); + os_findclose(d); + } + return cont; +} + +typedef struct filename_closure_s { + const char* roots; + const char* extensions; + char dir_sep; +} filename_closure_t; + +static void filename_completer( ic_completion_env_t* cenv, const char* prefix ) { + if (prefix == NULL) return; + filename_closure_t* fclosure = (filename_closure_t*)cenv->arg; + stringbuf_t* root_dir = sbuf_new(cenv->env->mem); + stringbuf_t* dir_prefix = sbuf_new(cenv->env->mem); + stringbuf_t* display = sbuf_new(cenv->env->mem); + if (root_dir!=NULL && dir_prefix != NULL && display != NULL) + { + // split prefix in dir_prefix / base. + const char* base = strrchr(prefix,'/'); + #ifdef _WIN32 + const char* base2 = strrchr(prefix,'\\'); + if (base == NULL || base2 > base) base = base2; + #endif + if (base != NULL) { + base++; + sbuf_append_n(dir_prefix, prefix, base - prefix ); // includes dir separator + } + + // absolute path + if (os_path_is_absolute(prefix)) { + // do not use roots but try to complete directly + if (base != NULL) { + sbuf_append_n( root_dir, prefix, (base - prefix)); // include dir separator + } + filename_complete_indir( cenv, root_dir, dir_prefix, display, + (base != NULL ? base : prefix), + fclosure->dir_sep, fclosure->extensions ); + } + else { + // relative path, complete with respect to every root. + const char* next; + const char* root = fclosure->roots; + while ( root != NULL ) { + // create full root in `root_dir` + sbuf_clear(root_dir); + next = strchr(root,';'); + if (next == NULL) { + sbuf_append( root_dir, root ); + root = NULL; + } + else { + sbuf_append_n( root_dir, root, next - root ); + root = next + 1; + } + sbuf_append_char( root_dir, ic_dirsep()); + + // add the dir_prefix to the root + if (base != NULL) { + sbuf_append_n( root_dir, prefix, (base - prefix) - 1); + } + + // and complete in this directory + filename_complete_indir( cenv, root_dir, dir_prefix, display, + (base != NULL ? base : prefix), + fclosure->dir_sep, fclosure->extensions); + } + } + } + sbuf_free(display); + sbuf_free(root_dir); + sbuf_free(dir_prefix); +} + +ic_public void ic_complete_filename( ic_completion_env_t* cenv, const char* prefix, char dir_sep, const char* roots, const char* extensions ) { + if (roots == NULL) roots = "."; + if (extensions == NULL) extensions = ""; + if (dir_sep == 0) dir_sep = ic_dirsep(); + filename_closure_t fclosure; + fclosure.dir_sep = dir_sep; + fclosure.roots = roots; + fclosure.extensions = extensions; + cenv->arg = &fclosure; + ic_complete_qword_ex( cenv, prefix, &filename_completer, &ic_char_is_filename_letter, '\\', "'\""); +} diff --git a/extern/isocline/src/completions.c b/extern/isocline/src/completions.c new file mode 100644 index 000000000..01453efc7 --- /dev/null +++ b/extern/isocline/src/completions.c @@ -0,0 +1,326 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" +#include "stringbuf.h" +#include "completions.h" + + +//------------------------------------------------------------- +// Completions +//------------------------------------------------------------- + +typedef struct completion_s { + const char* replacement; + const char* display; + const char* help; + ssize_t delete_before; + ssize_t delete_after; +} completion_t; + +struct completions_s { + ic_completer_fun_t* completer; + void* completer_arg; + ssize_t completer_max; + ssize_t count; + ssize_t len; + completion_t* elems; + alloc_t* mem; +}; + +static void default_filename_completer( ic_completion_env_t* cenv, const char* prefix ); + +ic_private completions_t* completions_new(alloc_t* mem) { + completions_t* cms = mem_zalloc_tp(mem, completions_t); + if (cms == NULL) return NULL; + cms->mem = mem; + cms->completer = &default_filename_completer; + return cms; +} + +ic_private void completions_free(completions_t* cms) { + if (cms == NULL) return; + completions_clear(cms); + if (cms->elems != NULL) { + mem_free(cms->mem, cms->elems); + cms->elems = NULL; + cms->count = 0; + cms->len = 0; + } + mem_free(cms->mem, cms); // free ourselves +} + + +ic_private void completions_clear(completions_t* cms) { + while (cms->count > 0) { + completion_t* cm = cms->elems + cms->count - 1; + mem_free( cms->mem, cm->display); + mem_free( cms->mem, cm->replacement); + mem_free( cms->mem, cm->help); + memset(cm,0,sizeof(*cm)); + cms->count--; + } +} + +static void completions_push(completions_t* cms, const char* replacement, const char* display, const char* help, ssize_t delete_before, ssize_t delete_after) +{ + if (cms->count >= cms->len) { + ssize_t newlen = (cms->len <= 0 ? 32 : cms->len*2); + completion_t* newelems = mem_realloc_tp(cms->mem, completion_t, cms->elems, newlen ); + if (newelems == NULL) return; + cms->elems = newelems; + cms->len = newlen; + } + assert(cms->count < cms->len); + completion_t* cm = cms->elems + cms->count; + cm->replacement = mem_strdup(cms->mem,replacement); + cm->display = mem_strdup(cms->mem,display); + cm->help = mem_strdup(cms->mem,help); + cm->delete_before = delete_before; + cm->delete_after = delete_after; + cms->count++; +} + +ic_private ssize_t completions_count(completions_t* cms) { + return cms->count; +} + +static bool completions_contains(completions_t* cms, const char* replacement) { + for( ssize_t i = 0; i < cms->count; i++ ) { + const completion_t* c = cms->elems + i; + if (strcmp(replacement,c->replacement) == 0) { return true; } + } + return false; +} + +ic_private bool completions_add(completions_t* cms, const char* replacement, const char* display, const char* help, ssize_t delete_before, ssize_t delete_after) { + if (cms->completer_max <= 0) return false; + cms->completer_max--; + //debug_msg("completion: add: %d,%d, %s\n", delete_before, delete_after, replacement); + if (!completions_contains(cms,replacement)) { + completions_push(cms, replacement, display, help, delete_before, delete_after); + } + return true; +} + +static completion_t* completions_get(completions_t* cms, ssize_t index) { + if (index < 0 || cms->count <= 0 || index >= cms->count) return NULL; + return &cms->elems[index]; +} + +ic_private const char* completions_get_display( completions_t* cms, ssize_t index, const char** help ) { + if (help != NULL) { *help = NULL; } + completion_t* cm = completions_get(cms, index); + if (cm == NULL) return NULL; + if (help != NULL) { *help = cm->help; } + return (cm->display != NULL ? cm->display : cm->replacement); +} + +ic_private const char* completions_get_help( completions_t* cms, ssize_t index ) { + completion_t* cm = completions_get(cms, index); + if (cm == NULL) return NULL; + return cm->help; +} + +ic_private const char* completions_get_hint(completions_t* cms, ssize_t index, const char** help) { + if (help != NULL) { *help = NULL; } + completion_t* cm = completions_get(cms, index); + if (cm == NULL) return NULL; + ssize_t len = ic_strlen(cm->replacement); + if (len < cm->delete_before) return NULL; + const char* hint = (cm->replacement + cm->delete_before); + if (*hint == 0 || utf8_is_cont((uint8_t)(*hint))) return NULL; // utf8 boundary? + if (help != NULL) { *help = cm->help; } + return hint; +} + +ic_private void completions_set_completer(completions_t* cms, ic_completer_fun_t* completer, void* arg) { + cms->completer = completer; + cms->completer_arg = arg; +} + +ic_private void completions_get_completer(completions_t* cms, ic_completer_fun_t** completer, void** arg) { + *completer = cms->completer; + *arg = cms->completer_arg; +} + + +ic_public void* ic_completion_arg( const ic_completion_env_t* cenv ) { + return (cenv == NULL ? NULL : cenv->env->completions->completer_arg); +} + +ic_public bool ic_has_completions( const ic_completion_env_t* cenv ) { + return (cenv == NULL ? false : cenv->env->completions->count > 0); +} + +ic_public bool ic_stop_completing( const ic_completion_env_t* cenv) { + return (cenv == NULL ? true : cenv->env->completions->completer_max <= 0); +} + + +static ssize_t completion_apply( completion_t* cm, stringbuf_t* sbuf, ssize_t pos ) { + if (cm == NULL) return -1; + debug_msg( "completion: apply: %s at %zd\n", cm->replacement, pos); + ssize_t start = pos - cm->delete_before; + if (start < 0) start = 0; + ssize_t n = cm->delete_before + cm->delete_after; + if (ic_strlen(cm->replacement) == n && strncmp(sbuf_string_at(sbuf,start), cm->replacement, to_size_t(n)) == 0) { + // no changes + return -1; + } + else { + sbuf_delete_from_to( sbuf, start, pos + cm->delete_after ); + return sbuf_insert_at(sbuf, cm->replacement, start); + } +} + +ic_private ssize_t completions_apply( completions_t* cms, ssize_t index, stringbuf_t* sbuf, ssize_t pos ) { + completion_t* cm = completions_get(cms, index); + return completion_apply( cm, sbuf, pos ); +} + + +static int completion_compare(const void* p1, const void* p2) { + if (p1 == NULL || p2 == NULL) return 0; + const completion_t* cm1 = (const completion_t*)p1; + const completion_t* cm2 = (const completion_t*)p2; + return ic_stricmp(cm1->replacement, cm2->replacement); +} + +ic_private void completions_sort(completions_t* cms) { + if (cms->count <= 0) return; + qsort(cms->elems, to_size_t(cms->count), sizeof(cms->elems[0]), &completion_compare); +} + +#define IC_MAX_PREFIX (256) + +// find longest common prefix and complete with that. +ic_private ssize_t completions_apply_longest_prefix(completions_t* cms, stringbuf_t* sbuf, ssize_t pos) { + if (cms->count <= 1) { + return completions_apply(cms,0,sbuf,pos); + } + + // set initial prefix to the first entry + completion_t* cm = completions_get(cms, 0); + if (cm == NULL) return -1; + + char prefix[IC_MAX_PREFIX+1]; + ssize_t delete_before = cm->delete_before; + ic_strncpy( prefix, IC_MAX_PREFIX+1, cm->replacement, IC_MAX_PREFIX ); + prefix[IC_MAX_PREFIX] = 0; + + // and visit all others to find the longest common prefix + for(ssize_t i = 1; i < cms->count; i++) { + cm = completions_get(cms,i); + if (cm->delete_before != delete_before) { // deletions must match delete_before + prefix[0] = 0; + break; + } + // check if it is still a prefix + const char* r = cm->replacement; + ssize_t j; + for(j = 0; prefix[j] != 0 && r[j] != 0; j++) { + if (prefix[j] != r[j]) break; + } + prefix[j] = 0; + if (j <= 0) break; + } + + // check the length + ssize_t len = ic_strlen(prefix); + if (len <= 0 || len < delete_before) return -1; + + // we found a prefix :-) + completion_t cprefix; + memset(&cprefix,0,sizeof(cprefix)); + cprefix.delete_before = delete_before; + cprefix.replacement = prefix; + ssize_t newpos = completion_apply( &cprefix, sbuf, pos); + if (newpos < 0) return newpos; + + // adjust all delete_before for the new replacement + for( ssize_t i = 0; i < cms->count; i++) { + cm = completions_get(cms,i); + cm->delete_before = len; + } + + return newpos; +} + + +//------------------------------------------------------------- +// Completer functions +//------------------------------------------------------------- + +ic_public bool ic_add_completions(ic_completion_env_t* cenv, const char* prefix, const char** completions) { + for (const char** pc = completions; *pc != NULL; pc++) { + if (ic_istarts_with(*pc, prefix)) { + if (!ic_add_completion_ex(cenv, *pc, NULL, NULL)) return false; + } + } + return true; +} + +ic_public bool ic_add_completion(ic_completion_env_t* cenv, const char* replacement) { + return ic_add_completion_ex(cenv, replacement, NULL, NULL); +} + +ic_public bool ic_add_completion_ex( ic_completion_env_t* cenv, const char* replacement, const char* display, const char* help ) { + return ic_add_completion_prim(cenv,replacement,display,help,0,0); +} + +ic_public bool ic_add_completion_prim(ic_completion_env_t* cenv, const char* replacement, const char* display, const char* help, long delete_before, long delete_after) { + return (*cenv->complete)(cenv->env, cenv->closure, replacement, display, help, delete_before, delete_after ); +} + +static bool prim_add_completion(ic_env_t* env, void* funenv, const char* replacement, const char* display, const char* help, long delete_before, long delete_after) { + ic_unused(funenv); + return completions_add(env->completions, replacement, display, help, delete_before, delete_after); +} + +ic_public void ic_set_default_completer(ic_completer_fun_t* completer, void* arg) { + ic_env_t* env = ic_get_env(); if (env == NULL) return; + completions_set_completer(env->completions, completer, arg); +} + +ic_private ssize_t completions_generate(struct ic_env_s* env, completions_t* cms, const char* input, ssize_t pos, ssize_t max) { + completions_clear(cms); + if (cms->completer == NULL || input == NULL || ic_strlen(input) < pos) return 0; + + // set up env + ic_completion_env_t cenv; + cenv.env = env; + cenv.input = input, + cenv.cursor = (long)pos; + cenv.arg = cms->completer_arg; + cenv.complete = &prim_add_completion; + cenv.closure = NULL; + const char* prefix = mem_strndup(cms->mem, input, pos); + cms->completer_max = max; + + // and complete + cms->completer(&cenv,prefix); + + // restore + mem_free(cms->mem,prefix); + return completions_count(cms); +} + +// The default completer is no completion is set +static void default_filename_completer( ic_completion_env_t* cenv, const char* prefix ) { + #ifdef _WIN32 + const char sep = '\\'; + #else + const char sep = '/'; + #endif + ic_complete_filename( cenv, prefix, sep, ".", NULL); +} diff --git a/extern/isocline/src/completions.h b/extern/isocline/src/completions.h new file mode 100644 index 000000000..8361d5078 --- /dev/null +++ b/extern/isocline/src/completions.h @@ -0,0 +1,52 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_COMPLETIONS_H +#define IC_COMPLETIONS_H + +#include "common.h" +#include "stringbuf.h" + + +//------------------------------------------------------------- +// Completions +//------------------------------------------------------------- +#define IC_MAX_COMPLETIONS_TO_SHOW (1000) +#define IC_MAX_COMPLETIONS_TO_TRY (IC_MAX_COMPLETIONS_TO_SHOW/4) + +typedef struct completions_s completions_t; + +ic_private completions_t* completions_new(alloc_t* mem); +ic_private void completions_free(completions_t* cms); +ic_private void completions_clear(completions_t* cms); +ic_private bool completions_add(completions_t* cms , const char* replacement, const char* display, const char* help, ssize_t delete_before, ssize_t delete_after); +ic_private ssize_t completions_count(completions_t* cms); +ic_private ssize_t completions_generate(struct ic_env_s* env, completions_t* cms , const char* input, ssize_t pos, ssize_t max); +ic_private void completions_sort(completions_t* cms); +ic_private void completions_set_completer(completions_t* cms, ic_completer_fun_t* completer, void* arg); +ic_private const char* completions_get_display(completions_t* cms , ssize_t index, const char** help); +ic_private const char* completions_get_hint(completions_t* cms, ssize_t index, const char** help); +ic_private void completions_get_completer(completions_t* cms, ic_completer_fun_t** completer, void** arg); + +ic_private ssize_t completions_apply(completions_t* cms, ssize_t index, stringbuf_t* sbuf, ssize_t pos); +ic_private ssize_t completions_apply_longest_prefix(completions_t* cms, stringbuf_t* sbuf, ssize_t pos); + +//------------------------------------------------------------- +// Completion environment +//------------------------------------------------------------- +typedef bool (ic_completion_fun_t)( ic_env_t* env, void* funenv, const char* replacement, const char* display, const char* help, long delete_before, long delete_after ); + +struct ic_completion_env_s { + ic_env_t* env; // the isocline environment + const char* input; // current full input + long cursor; // current cursor position + void* arg; // argument given to `ic_set_completer` + void* closure; // free variables for function composition + ic_completion_fun_t* complete; // function that adds a completion +}; + +#endif // IC_COMPLETIONS_H diff --git a/extern/isocline/src/editline.c b/extern/isocline/src/editline.c new file mode 100644 index 000000000..270c42d92 --- /dev/null +++ b/extern/isocline/src/editline.c @@ -0,0 +1,1142 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include + +#include "common.h" +#include "term.h" +#include "tty.h" +#include "env.h" +#include "stringbuf.h" +#include "history.h" +#include "completions.h" +#include "undo.h" +#include "highlight.h" + +//------------------------------------------------------------- +// The editor state +//------------------------------------------------------------- + + + +// editor state +typedef struct editor_s { + stringbuf_t* input; // current user input + stringbuf_t* extra; // extra displayed info (for completion menu etc) + stringbuf_t* hint; // hint displayed as part of the input + stringbuf_t* hint_help; // help for a hint. + ssize_t pos; // current cursor position in the input + ssize_t cur_rows; // current used rows to display our content (including extra content) + ssize_t cur_row; // current row that has the cursor (0 based, relative to the prompt) + ssize_t termw; + bool modified; // has a modification happened? (used for history navigation for example) + bool disable_undo; // temporarily disable auto undo (for history search) + ssize_t history_idx; // current index in the history + editstate_t* undo; // undo buffer + editstate_t* redo; // redo buffer + const char* prompt_text; // text of the prompt before the prompt marker + alloc_t* mem; // allocator + // caches + attrbuf_t* attrs; // reuse attribute buffers + attrbuf_t* attrs_extra; +} editor_t; + + + + + +//------------------------------------------------------------- +// Main edit line +//------------------------------------------------------------- +static char* edit_line( ic_env_t* env, const char* prompt_text ); // defined at bottom +static void edit_refresh(ic_env_t* env, editor_t* eb); + +ic_private char* ic_editline(ic_env_t* env, const char* prompt_text) { + tty_start_raw(env->tty); + term_start_raw(env->term); + char* line = edit_line(env,prompt_text); + term_end_raw(env->term,false); + tty_end_raw(env->tty); + term_writeln(env->term,""); + term_flush(env->term); + return line; +} + + +//------------------------------------------------------------- +// Undo/Redo +//------------------------------------------------------------- + +// capture the current edit state +static void editor_capture(editor_t* eb, editstate_t** es ) { + if (!eb->disable_undo) { + editstate_capture( eb->mem, es, sbuf_string(eb->input), eb->pos ); + } +} + +static void editor_undo_capture(editor_t* eb ) { + editor_capture(eb, &eb->undo ); +} + +static void editor_undo_forget(editor_t* eb) { + if (eb->disable_undo) return; + const char* input = NULL; + ssize_t pos = 0; + editstate_restore(eb->mem, &eb->undo, &input, &pos); + mem_free(eb->mem, input); +} + +static void editor_restore(editor_t* eb, editstate_t** from, editstate_t** to ) { + if (eb->disable_undo) return; + if (*from == NULL) return; + const char* input; + if (to != NULL) { editor_capture( eb, to ); } + if (!editstate_restore( eb->mem, from, &input, &eb->pos )) return; + sbuf_replace( eb->input, input ); + mem_free(eb->mem, input); + eb->modified = false; +} + +static void editor_undo_restore(editor_t* eb, bool with_redo ) { + editor_restore(eb, &eb->undo, (with_redo ? &eb->redo : NULL)); +} + +static void editor_redo_restore(editor_t* eb ) { + editor_restore(eb, &eb->redo, &eb->undo); + eb->modified = false; +} + +static void editor_start_modify(editor_t* eb ) { + editor_undo_capture(eb); + editstate_done(eb->mem, &eb->redo); // clear redo + eb->modified = true; +} + + + +static bool editor_pos_is_at_end(editor_t* eb ) { + return (eb->pos == sbuf_len(eb->input)); +} + +//------------------------------------------------------------- +// Row/Column width and positioning +//------------------------------------------------------------- + + +static void edit_get_prompt_width( ic_env_t* env, editor_t* eb, bool in_extra, ssize_t* promptw, ssize_t* cpromptw ) { + if (in_extra) { + *promptw = 0; + *cpromptw = 0; + } + else { + // todo: cache prompt widths + ssize_t textw = bbcode_column_width(env->bbcode, eb->prompt_text); + ssize_t markerw = bbcode_column_width(env->bbcode, env->prompt_marker); + ssize_t cmarkerw = bbcode_column_width(env->bbcode, env->cprompt_marker); + *promptw = markerw + textw; + *cpromptw = (env->no_multiline_indent || *promptw < cmarkerw ? cmarkerw : *promptw); + } +} + +static ssize_t edit_get_rowcol( ic_env_t* env, editor_t* eb, rowcol_t* rc ) { + ssize_t promptw, cpromptw; + edit_get_prompt_width(env, eb, false, &promptw, &cpromptw); + return sbuf_get_rc_at_pos( eb->input, eb->termw, promptw, cpromptw, eb->pos, rc ); +} + +static void edit_set_pos_at_rowcol( ic_env_t* env, editor_t* eb, ssize_t row, ssize_t col ) { + ssize_t promptw, cpromptw; + edit_get_prompt_width(env, eb, false, &promptw, &cpromptw); + ssize_t pos = sbuf_get_pos_at_rc( eb->input, eb->termw, promptw, cpromptw, row, col ); + if (pos < 0) return; + eb->pos = pos; + edit_refresh(env, eb); +} + +static bool edit_pos_is_at_row_end( ic_env_t* env, editor_t* eb ) { + rowcol_t rc; + edit_get_rowcol( env, eb, &rc ); + return rc.last_on_row; +} + +static void edit_write_prompt( ic_env_t* env, editor_t* eb, ssize_t row, bool in_extra ) { + if (in_extra) return; + bbcode_style_open(env->bbcode, "ic-prompt"); + if (row==0) { + // regular prompt text + bbcode_print( env->bbcode, eb->prompt_text ); + } + else if (!env->no_multiline_indent) { + // multiline continuation indentation + // todo: cache prompt widths + ssize_t textw = bbcode_column_width(env->bbcode, eb->prompt_text ); + ssize_t markerw = bbcode_column_width(env->bbcode, env->prompt_marker); + ssize_t cmarkerw = bbcode_column_width(env->bbcode, env->cprompt_marker); + if (cmarkerw < markerw + textw) { + term_write_repeat(env->term, " ", markerw + textw - cmarkerw ); + } + } + // the marker + bbcode_print(env->bbcode, (row == 0 ? env->prompt_marker : env->cprompt_marker )); + bbcode_style_close(env->bbcode,NULL); +} + +//------------------------------------------------------------- +// Refresh +//------------------------------------------------------------- + +typedef struct refresh_info_s { + ic_env_t* env; + editor_t* eb; + attrbuf_t* attrs; + bool in_extra; + ssize_t first_row; + ssize_t last_row; +} refresh_info_t; + +static bool edit_refresh_rows_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(res); ic_unused(startw); + const refresh_info_t* info = (const refresh_info_t*)(arg); + term_t* term = info->env->term; + + // debug_msg("edit: line refresh: row %zd, len: %zd\n", row, row_len); + if (row < info->first_row) return false; + if (row > info->last_row) return true; // should not occur + + // term_clear_line(term); + edit_write_prompt(info->env, info->eb, row, info->in_extra); + + //' write output + if (info->attrs == NULL || (info->env->no_highlight && info->env->no_bracematch)) { + term_write_n( term, s + row_start, row_len ); + } + else { + term_write_formatted_n( term, s + row_start, attrbuf_attrs(info->attrs, row_start + row_len) + row_start, row_len ); + } + + // write line ending + if (row < info->last_row) { + if (is_wrap && tty_is_utf8(info->env->tty)) { + #ifndef __APPLE__ + bbcode_print( info->env->bbcode, "[ic-dim]\xE2\x86\x90"); // left arrow + #else + bbcode_print( info->env->bbcode, "[ic-dim]\xE2\x86\xB5" ); // return symbol + #endif + } + term_clear_to_end_of_line(term); + term_writeln(term, ""); + } + else { + term_clear_to_end_of_line(term); + } + return (row >= info->last_row); +} + +static void edit_refresh_rows(ic_env_t* env, editor_t* eb, stringbuf_t* input, attrbuf_t* attrs, + ssize_t promptw, ssize_t cpromptw, bool in_extra, + ssize_t first_row, ssize_t last_row) +{ + if (input == NULL) return; + refresh_info_t info; + info.env = env; + info.eb = eb; + info.attrs = attrs; + info.in_extra = in_extra; + info.first_row = first_row; + info.last_row = last_row; + sbuf_for_each_row( input, eb->termw, promptw, cpromptw, &edit_refresh_rows_iter, &info, NULL); +} + + +static void edit_refresh(ic_env_t* env, editor_t* eb) +{ + // calculate the new cursor row and total rows needed + ssize_t promptw, cpromptw; + edit_get_prompt_width( env, eb, false, &promptw, &cpromptw ); + + if (eb->attrs != NULL) { + highlight( env->mem, env->bbcode, sbuf_string(eb->input), eb->attrs, + (env->no_highlight ? NULL : env->highlighter), env->highlighter_arg ); + } + + // highlight matching braces + if (eb->attrs != NULL && !env->no_bracematch) { + highlight_match_braces(sbuf_string(eb->input), eb->attrs, eb->pos, ic_env_get_match_braces(env), + bbcode_style(env->bbcode,"ic-bracematch"), bbcode_style(env->bbcode,"ic-error")); + } + + // insert hint + if (sbuf_len(eb->hint) > 0) { + if (eb->attrs != NULL) { + attrbuf_insert_at( eb->attrs, eb->pos, sbuf_len(eb->hint), bbcode_style(env->bbcode, "ic-hint") ); + } + sbuf_insert_at(eb->input, sbuf_string(eb->hint), eb->pos ); + } + + // render extra (like a completion menu) + stringbuf_t* extra = NULL; + if (sbuf_len(eb->extra) > 0) { + extra = sbuf_new(eb->mem); + if (extra != NULL) { + if (sbuf_len(eb->hint_help) > 0) { + bbcode_append(env->bbcode, sbuf_string(eb->hint_help), extra, eb->attrs_extra); + } + bbcode_append(env->bbcode, sbuf_string(eb->extra), extra, eb->attrs_extra); + } + } + + // calculate rows and row/col position + rowcol_t rc = { 0 }; + const ssize_t rows_input = sbuf_get_rc_at_pos( eb->input, eb->termw, promptw, cpromptw, eb->pos, &rc ); + rowcol_t rc_extra = { 0 }; + ssize_t rows_extra = 0; + if (extra != NULL) { + rows_extra = sbuf_get_rc_at_pos( extra, eb->termw, 0, 0, 0 /*pos*/, &rc_extra ); + } + const ssize_t rows = rows_input + rows_extra; + debug_msg("edit: refresh: rows %zd, cursor: %zd,%zd (previous rows %zd, cursor row %zd)\n", rows, rc.row, rc.col, eb->cur_rows, eb->cur_row); + + // only render at most terminal height rows + const ssize_t termh = term_get_height(env->term); + ssize_t first_row = 0; // first visible row + ssize_t last_row = rows - 1; // last visible row + if (rows > termh) { + first_row = rc.row - termh + 1; // ensure cursor is visible + if (first_row < 0) first_row = 0; + last_row = first_row + termh - 1; + } + assert(last_row - first_row < termh); + + // reduce flicker + buffer_mode_t bmode = term_set_buffer_mode(env->term, BUFFERED); + + // back up to the first line + term_start_of_line(env->term); + term_up(env->term, (eb->cur_row >= termh ? termh-1 : eb->cur_row) ); + // term_clear_lines_to_end(env->term); // gives flicker in old Windows cmd prompt + + // render rows + edit_refresh_rows( env, eb, eb->input, eb->attrs, promptw, cpromptw, false, first_row, last_row ); + if (rows_extra > 0) { + assert(extra != NULL); + const ssize_t first_rowx = (first_row > rows_input ? first_row - rows_input : 0); + const ssize_t last_rowx = last_row - rows_input; assert(last_rowx >= 0); + edit_refresh_rows(env, eb, extra, eb->attrs_extra, 0, 0, true, first_rowx, last_rowx); + } + + // overwrite trailing rows we do not use anymore + ssize_t rrows = last_row - first_row + 1; // rendered rows + if (rrows < termh && rows < eb->cur_rows) { + ssize_t clear = eb->cur_rows - rows; + while (rrows < termh && clear > 0) { + clear--; + rrows++; + term_writeln(env->term,""); + term_clear_line(env->term); + } + } + + // move cursor back to edit position + term_start_of_line(env->term); + term_up(env->term, first_row + rrows - 1 - rc.row ); + term_right(env->term, rc.col + (rc.row == 0 ? promptw : cpromptw)); + + // and refresh + term_flush(env->term); + + // stop buffering + term_set_buffer_mode(env->term, bmode); + + // restore input by removing the hint + sbuf_delete_at(eb->input, eb->pos, sbuf_len(eb->hint)); + sbuf_delete_at(eb->extra, 0, sbuf_len(eb->hint_help)); + attrbuf_clear(eb->attrs); + attrbuf_clear(eb->attrs_extra); + sbuf_free(extra); + + // update previous + eb->cur_rows = rows; + eb->cur_row = rc.row; +} + +// clear current output +static void edit_clear(ic_env_t* env, editor_t* eb ) { + term_attr_reset(env->term); + term_up(env->term, eb->cur_row); + + // overwrite all rows + for( ssize_t i = 0; i < eb->cur_rows; i++) { + term_clear_line(env->term); + term_writeln(env->term, ""); + } + + // move cursor back + term_up(env->term, eb->cur_rows - eb->cur_row ); +} + + +// clear screen and refresh +static void edit_clear_screen(ic_env_t* env, editor_t* eb ) { + ssize_t cur_rows = eb->cur_rows; + eb->cur_rows = term_get_height(env->term) - 1; + edit_clear(env,eb); + eb->cur_rows = cur_rows; + edit_refresh(env,eb); +} + + +// refresh after a terminal window resized (but before doing further edit operations!) +static bool edit_resize(ic_env_t* env, editor_t* eb ) { + // update dimensions + term_update_dim(env->term); + ssize_t newtermw = term_get_width(env->term); + if (eb->termw == newtermw) return false; + + // recalculate the row layout assuming the hardwrapping for the new terminal width + ssize_t promptw, cpromptw; + edit_get_prompt_width( env, eb, false, &promptw, &cpromptw ); + sbuf_insert_at(eb->input, sbuf_string(eb->hint), eb->pos); // insert used hint + + // render extra (like a completion menu) + stringbuf_t* extra = NULL; + if (sbuf_len(eb->extra) > 0) { + extra = sbuf_new(eb->mem); + if (extra != NULL) { + if (sbuf_len(eb->hint_help) > 0) { + bbcode_append(env->bbcode, sbuf_string(eb->hint_help), extra, NULL); + } + bbcode_append(env->bbcode, sbuf_string(eb->extra), extra, NULL); + } + } + rowcol_t rc = { 0 }; + const ssize_t rows_input = sbuf_get_wrapped_rc_at_pos( eb->input, eb->termw, newtermw, promptw, cpromptw, eb->pos, &rc ); + rowcol_t rc_extra = { 0 }; + ssize_t rows_extra = 0; + if (extra != NULL) { + rows_extra = sbuf_get_wrapped_rc_at_pos(extra, eb->termw, newtermw, 0, 0, 0 /*pos*/, &rc_extra); + } + ssize_t rows = rows_input + rows_extra; + debug_msg("edit: resize: new rows: %zd, cursor row: %zd (previous: rows: %zd, cursor row %zd)\n", rows, rc.row, eb->cur_rows, eb->cur_row); + + // update the newly calculated row and rows + eb->cur_row = rc.row; + if (rows > eb->cur_rows) { + eb->cur_rows = rows; + } + eb->termw = newtermw; + edit_refresh(env,eb); + + // remove hint again + sbuf_delete_at(eb->input, eb->pos, sbuf_len(eb->hint)); + sbuf_free(extra); + return true; +} + +static void editor_append_hint_help(editor_t* eb, const char* help) { + sbuf_clear(eb->hint_help); + if (help != NULL) { + sbuf_replace(eb->hint_help, "[ic-info]"); + sbuf_append(eb->hint_help, help); + sbuf_append(eb->hint_help, "[/ic-info]\n"); + } +} + +// refresh with possible hint +static void edit_refresh_hint(ic_env_t* env, editor_t* eb) { + if (env->no_hint || env->hint_delay > 0) { + // refresh without hint first + edit_refresh(env, eb); + if (env->no_hint) return; + } + + // and see if we can construct a hint (displayed after a delay) + ssize_t count = completions_generate(env, env->completions, sbuf_string(eb->input), eb->pos, 2); + if (count == 1) { + const char* help = NULL; + const char* hint = completions_get_hint(env->completions, 0, &help); + if (hint != NULL) { + sbuf_replace(eb->hint, hint); + editor_append_hint_help(eb, help); + // do auto-tabbing? + if (env->complete_autotab) { + stringbuf_t* sb = sbuf_new(env->mem); // temporary buffer for completion + if (sb != NULL) { + sbuf_replace( sb, sbuf_string(eb->input) ); + ssize_t pos = eb->pos; + const char* extra_hint = hint; + do { + ssize_t newpos = sbuf_insert_at( sb, extra_hint, pos ); + if (newpos <= pos) break; + pos = newpos; + count = completions_generate(env, env->completions, sbuf_string(sb), pos, 2); + if (count == 1) { + const char* extra_help = NULL; + extra_hint = completions_get_hint(env->completions, 0, &extra_help); + if (extra_hint != NULL) { + editor_append_hint_help(eb, extra_help); + sbuf_append(eb->hint, extra_hint); + } + } + } + while(count == 1); + sbuf_free(sb); + } + } + } + } + + if (env->hint_delay <= 0) { + // refresh with hint directly + edit_refresh(env, eb); + } +} + +//------------------------------------------------------------- +// Edit operations +//------------------------------------------------------------- + +static void edit_history_prev(ic_env_t* env, editor_t* eb); +static void edit_history_next(ic_env_t* env, editor_t* eb); + +static void edit_undo_restore(ic_env_t* env, editor_t* eb) { + editor_undo_restore(eb, true); + edit_refresh(env,eb); +} + +static void edit_redo_restore(ic_env_t* env, editor_t* eb) { + editor_redo_restore(eb); + edit_refresh(env,eb); +} + +static void edit_cursor_left(ic_env_t* env, editor_t* eb) { + ssize_t cwidth = 1; + ssize_t prev = sbuf_prev(eb->input,eb->pos,&cwidth); + if (prev < 0) return; + rowcol_t rc; + edit_get_rowcol( env, eb, &rc); + eb->pos = prev; + edit_refresh(env,eb); +} + +static void edit_cursor_right(ic_env_t* env, editor_t* eb) { + ssize_t cwidth = 1; + ssize_t next = sbuf_next(eb->input,eb->pos,&cwidth); + if (next < 0) return; + rowcol_t rc; + edit_get_rowcol( env, eb, &rc); + eb->pos = next; + edit_refresh(env,eb); +} + +static void edit_cursor_line_end(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + eb->pos = end; + edit_refresh(env,eb); +} + +static void edit_cursor_line_start(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_cursor_next_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_word_end(eb->input,eb->pos); + if (end < 0) return; + eb->pos = end; + edit_refresh(env,eb); +} + +static void edit_cursor_prev_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_word_start(eb->input,eb->pos); + if (start < 0) return; + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_cursor_next_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_ws_word_end(eb->input, eb->pos); + if (end < 0) return; + eb->pos = end; + edit_refresh(env, eb); +} + +static void edit_cursor_prev_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_ws_word_start(eb->input, eb->pos); + if (start < 0) return; + eb->pos = start; + edit_refresh(env, eb); +} + +static void edit_cursor_to_start(ic_env_t* env, editor_t* eb) { + eb->pos = 0; + edit_refresh(env,eb); +} + +static void edit_cursor_to_end(ic_env_t* env, editor_t* eb) { + eb->pos = sbuf_len(eb->input); + edit_refresh(env,eb); +} + + +static void edit_cursor_row_up(ic_env_t* env, editor_t* eb) { + rowcol_t rc; + edit_get_rowcol( env, eb, &rc); + if (rc.row == 0) { + edit_history_prev(env,eb); + } + else { + edit_set_pos_at_rowcol( env, eb, rc.row - 1, rc.col ); + } +} + +static void edit_cursor_row_down(ic_env_t* env, editor_t* eb) { + rowcol_t rc; + ssize_t rows = edit_get_rowcol( env, eb, &rc); + if (rc.row + 1 >= rows) { + edit_history_next(env,eb); + } + else { + edit_set_pos_at_rowcol( env, eb, rc.row + 1, rc.col ); + } +} + + +static void edit_cursor_match_brace(ic_env_t* env, editor_t* eb) { + ssize_t match = find_matching_brace( sbuf_string(eb->input), eb->pos, ic_env_get_match_braces(env), NULL ); + if (match < 0) return; + eb->pos = match; + edit_refresh(env,eb); +} + +static void edit_backspace(ic_env_t* env, editor_t* eb) { + if (eb->pos <= 0) return; + editor_start_modify(eb); + eb->pos = sbuf_delete_char_before(eb->input,eb->pos); + edit_refresh(env,eb); +} + +static void edit_delete_char(ic_env_t* env, editor_t* eb) { + if (eb->pos >= sbuf_len(eb->input)) return; + editor_start_modify(eb); + sbuf_delete_char_at(eb->input,eb->pos); + edit_refresh(env,eb); +} + +static void edit_delete_all(ic_env_t* env, editor_t* eb) { + if (sbuf_len(eb->input) <= 0) return; + editor_start_modify(eb); + sbuf_clear(eb->input); + eb->pos = 0; + edit_refresh(env,eb); +} + +static void edit_delete_to_end_of_line(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + // if on an empty line, remove it completely + if (start == end && sbuf_char_at(eb->input,end) == '\n') { + end++; + } + else if (start == end && sbuf_char_at(eb->input,start - 1) == '\n') { + eb->pos--; + } + sbuf_delete_from_to( eb->input, eb->pos, end ); + edit_refresh(env,eb); +} + +static void edit_delete_to_start_of_line(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + // delete start newline if it was an empty line + bool goright = false; + if (start > 0 && sbuf_char_at(eb->input,start-1) == '\n' && start == end) { + // if it is an empty line remove it + start--; + // afterwards, move to start of next line if it exists (so the cursor stays on the same row) + goright = true; + } + sbuf_delete_from_to( eb->input, start, eb->pos ); + eb->pos = start; + if (goright) edit_cursor_right(env,eb); + edit_refresh(env,eb); +} + +static void edit_delete_line(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_line_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_line_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + // delete newline as well so no empty line is left; + bool goright = false; + if (start > 0 && sbuf_char_at(eb->input,start-1) == '\n') { + start--; + // afterwards, move to start of next line if it exists (so the cursor stays on the same row) + goright = true; + } + else if (sbuf_char_at(eb->input,end) == '\n') { + end++; + } + sbuf_delete_from_to(eb->input,start,end); + eb->pos = start; + if (goright) edit_cursor_right(env,eb); + edit_refresh(env,eb); +} + +static void edit_delete_to_start_of_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_word_start(eb->input,eb->pos); + if (start < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to( eb->input, start, eb->pos ); + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_delete_to_end_of_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_word_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to( eb->input, eb->pos, end ); + edit_refresh(env,eb); +} + +static void edit_delete_to_start_of_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_ws_word_start(eb->input, eb->pos); + if (start < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to(eb->input, start, eb->pos); + eb->pos = start; + edit_refresh(env, eb); +} + +static void edit_delete_to_end_of_ws_word(ic_env_t* env, editor_t* eb) { + ssize_t end = sbuf_find_ws_word_end(eb->input, eb->pos); + if (end < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to(eb->input, eb->pos, end); + edit_refresh(env, eb); +} + + +static void edit_delete_word(ic_env_t* env, editor_t* eb) { + ssize_t start = sbuf_find_word_start(eb->input,eb->pos); + if (start < 0) return; + ssize_t end = sbuf_find_word_end(eb->input,eb->pos); + if (end < 0) return; + editor_start_modify(eb); + sbuf_delete_from_to(eb->input,start,end); + eb->pos = start; + edit_refresh(env,eb); +} + +static void edit_swap_char( ic_env_t* env, editor_t* eb ) { + if (eb->pos <= 0 || eb->pos == sbuf_len(eb->input)) return; + editor_start_modify(eb); + eb->pos = sbuf_swap_char(eb->input,eb->pos); + edit_refresh(env,eb); +} + +static void edit_multiline_eol(ic_env_t* env, editor_t* eb) { + if (eb->pos <= 0) return; + if (sbuf_string(eb->input)[eb->pos-1] != env->multiline_eol) return; + editor_start_modify(eb); + // replace line continuation with a real newline + sbuf_delete_at( eb->input, eb->pos-1, 1); + sbuf_insert_at( eb->input, "\n", eb->pos-1); + edit_refresh(env,eb); +} + +static void edit_insert_unicode(ic_env_t* env, editor_t* eb, unicode_t u) { + editor_start_modify(eb); + ssize_t nextpos = sbuf_insert_unicode_at(eb->input, u, eb->pos); + if (nextpos >= 0) eb->pos = nextpos; + edit_refresh_hint(env, eb); +} + +static void edit_auto_brace(ic_env_t* env, editor_t* eb, char c) { + if (env->no_autobrace) return; + const char* braces = ic_env_get_auto_braces(env); + for (const char* b = braces; *b != 0; b += 2) { + if (*b == c) { + const char close = b[1]; + //if (sbuf_char_at(eb->input, eb->pos) != close) { + sbuf_insert_char_at(eb->input, close, eb->pos); + bool balanced = false; + find_matching_brace(sbuf_string(eb->input), eb->pos, braces, &balanced ); + if (!balanced) { + // don't insert if it leads to an unbalanced expression. + sbuf_delete_char_at(eb->input, eb->pos); + } + //} + return; + } + else if (b[1] == c) { + // close brace, check if there we don't overwrite to the right + if (sbuf_char_at(eb->input, eb->pos) == c) { + sbuf_delete_char_at(eb->input, eb->pos); + } + return; + } + } +} + +static void editor_auto_indent(editor_t* eb, const char* pre, const char* post ) { + assert(eb->pos > 0 && sbuf_char_at(eb->input,eb->pos-1) == '\n'); + ssize_t prelen = ic_strlen(pre); + if (prelen > 0) { + if (eb->pos - 1 < prelen) return; + if (!ic_starts_with(sbuf_string(eb->input) + eb->pos - 1 - prelen, pre)) return; + if (!ic_starts_with(sbuf_string(eb->input) + eb->pos, post)) return; + eb->pos = sbuf_insert_at(eb->input, " ", eb->pos); + sbuf_insert_char_at(eb->input, '\n', eb->pos); + } +} + +static void edit_insert_char(ic_env_t* env, editor_t* eb, char c) { + editor_start_modify(eb); + ssize_t nextpos = sbuf_insert_char_at( eb->input, c, eb->pos ); + if (nextpos >= 0) eb->pos = nextpos; + edit_auto_brace(env, eb, c); + if (c=='\n') { + editor_auto_indent(eb, "{", "}"); // todo: custom auto indent tokens? + } + edit_refresh_hint(env,eb); +} + +//------------------------------------------------------------- +// Help +//------------------------------------------------------------- + +#include "editline_help.c" + +//------------------------------------------------------------- +// History +//------------------------------------------------------------- + +#include "editline_history.c" + +//------------------------------------------------------------- +// Completion +//------------------------------------------------------------- + +#include "editline_completion.c" + + +//------------------------------------------------------------- +// Edit line: main edit loop +//------------------------------------------------------------- + +static char* edit_line( ic_env_t* env, const char* prompt_text ) +{ + // set up an edit buffer + editor_t eb; + memset(&eb, 0, sizeof(eb)); + eb.mem = env->mem; + eb.input = sbuf_new(env->mem); + eb.extra = sbuf_new(env->mem); + eb.hint = sbuf_new(env->mem); + eb.hint_help= sbuf_new(env->mem); + eb.termw = term_get_width(env->term); + eb.pos = 0; + eb.cur_rows = 1; + eb.cur_row = 0; + eb.modified = false; + eb.prompt_text = (prompt_text != NULL ? prompt_text : ""); + eb.history_idx = 0; + editstate_init(&eb.undo); + editstate_init(&eb.redo); + if (eb.input==NULL || eb.extra==NULL || eb.hint==NULL || eb.hint_help==NULL) { + return NULL; + } + + // caching + if (!(env->no_highlight && env->no_bracematch)) { + eb.attrs = attrbuf_new(env->mem); + eb.attrs_extra = attrbuf_new(env->mem); + } + + // show prompt + edit_write_prompt(env, &eb, 0, false); + + // always a history entry for the current input + history_push(env->history, ""); + + // process keys + code_t c; // current key code + while(true) { + // read a character + term_flush(env->term); + if (env->hint_delay <= 0 || sbuf_len(eb.hint) == 0) { + // blocking read + c = tty_read(env->tty); + } + else { + // timeout to display hint + if (!tty_read_timeout(env->tty, env->hint_delay, &c)) { + // timed-out + if (sbuf_len(eb.hint) > 0) { + // display hint + edit_refresh(env, &eb); + } + c = tty_read(env->tty); + } + else { + // clear the pending hint if we got input before the delay expired + sbuf_clear(eb.hint); + sbuf_clear(eb.hint_help); + } + } + + // update terminal in case of a resize + if (tty_term_resize_event(env->tty)) { + edit_resize(env,&eb); + } + + // clear hint only after a potential resize (so resize row calculations are correct) + const bool had_hint = (sbuf_len(eb.hint) > 0); + sbuf_clear(eb.hint); + sbuf_clear(eb.hint_help); + + // if the user tries to move into a hint with left-cursor or end, we complete it first + if ((c == KEY_RIGHT || c == KEY_END) && had_hint) { + edit_generate_completions(env, &eb, true); + c = KEY_NONE; + } + + // Operations that may return + if (c == KEY_ENTER) { + if (!env->singleline_only && eb.pos > 0 && + sbuf_string(eb.input)[eb.pos-1] == env->multiline_eol && + edit_pos_is_at_row_end(env,&eb)) + { + // replace line-continuation with newline + edit_multiline_eol(env,&eb); + } + else { + // otherwise done + break; + } + } + else if (c == KEY_CTRL_D) { + if (eb.pos == 0 && editor_pos_is_at_end(&eb)) break; // ctrl+D on empty quits with NULL + edit_delete_char(env,&eb); // otherwise it is like delete + } + else if (c == KEY_CTRL_C || c == KEY_EVENT_STOP) { + break; // ctrl+C or STOP event quits with NULL + } + else if (c == KEY_ESC) { + if (eb.pos == 0 && editor_pos_is_at_end(&eb)) break; // ESC on empty input returns with empty input + edit_delete_all(env,&eb); // otherwise delete the current input + // edit_delete_line(env,&eb); // otherwise delete the current line + } + else if (c == KEY_BELL /* ^G */) { + edit_delete_all(env,&eb); + break; // ctrl+G cancels (and returns empty input) + } + + // Editing Operations + else switch(c) { + // events + case KEY_EVENT_RESIZE: // not used + edit_resize(env,&eb); + break; + case KEY_EVENT_AUTOTAB: + edit_generate_completions(env, &eb, true); + break; + + // completion, history, help, undo + case KEY_TAB: + case WITH_ALT('?'): + edit_generate_completions(env,&eb,false); + break; + case KEY_CTRL_R: + case KEY_CTRL_S: + edit_history_search_with_current_word(env,&eb); + break; + case KEY_CTRL_P: + edit_history_prev(env, &eb); + break; + case KEY_CTRL_N: + edit_history_next(env, &eb); + break; + case KEY_CTRL_L: + edit_clear_screen(env, &eb); + break; + case KEY_CTRL_Z: + case WITH_CTRL('_'): + edit_undo_restore(env, &eb); + break; + case KEY_CTRL_Y: + edit_redo_restore(env, &eb); + break; + case KEY_F1: + edit_show_help(env, &eb); + break; + + // navigation + case KEY_LEFT: + case KEY_CTRL_B: + edit_cursor_left(env,&eb); + break; + case KEY_RIGHT: + case KEY_CTRL_F: + if (eb.pos == sbuf_len(eb.input)) { + edit_generate_completions( env, &eb, false ); + } + else { + edit_cursor_right(env,&eb); + } + break; + case KEY_UP: + edit_cursor_row_up(env,&eb); + break; + case KEY_DOWN: + edit_cursor_row_down(env,&eb); + break; + case KEY_HOME: + case KEY_CTRL_A: + edit_cursor_line_start(env,&eb); + break; + case KEY_END: + case KEY_CTRL_E: + edit_cursor_line_end(env,&eb); + break; + case KEY_CTRL_LEFT: + case WITH_SHIFT(KEY_LEFT): + case WITH_ALT('b'): + edit_cursor_prev_word(env,&eb); + break; + case KEY_CTRL_RIGHT: + case WITH_SHIFT(KEY_RIGHT): + case WITH_ALT('f'): + if (eb.pos == sbuf_len(eb.input)) { + edit_generate_completions( env, &eb, false ); + } + else { + edit_cursor_next_word(env,&eb); + } + break; + case KEY_CTRL_HOME: + case WITH_SHIFT(KEY_HOME): + case KEY_PAGEUP: + case WITH_ALT('<'): + edit_cursor_to_start(env,&eb); + break; + case KEY_CTRL_END: + case WITH_SHIFT(KEY_END): + case KEY_PAGEDOWN: + case WITH_ALT('>'): + edit_cursor_to_end(env,&eb); + break; + case WITH_ALT('m'): + edit_cursor_match_brace(env,&eb); + break; + + // deletion + case KEY_BACKSP: + edit_backspace(env,&eb); + break; + case KEY_DEL: + edit_delete_char(env,&eb); + break; + case WITH_ALT('d'): + edit_delete_to_end_of_word(env,&eb); + break; + case KEY_CTRL_W: + edit_delete_to_start_of_ws_word(env, &eb); + break; + case WITH_ALT(KEY_DEL): + case WITH_ALT(KEY_BACKSP): + edit_delete_to_start_of_word(env,&eb); + break; + case KEY_CTRL_U: + edit_delete_to_start_of_line(env,&eb); + break; + case KEY_CTRL_K: + edit_delete_to_end_of_line(env,&eb); + break; + case KEY_CTRL_T: + edit_swap_char(env,&eb); + break; + + // Editing + case KEY_SHIFT_TAB: + case KEY_LINEFEED: // '\n' (ctrl+J, shift+enter) + if (!env->singleline_only) { + edit_insert_char(env, &eb, '\n'); + } + break; + default: { + char chr; + unicode_t uchr; + if (code_is_ascii_char(c,&chr)) { + edit_insert_char(env,&eb,chr); + } + else if (code_is_unicode(c, &uchr)) { + edit_insert_unicode(env,&eb, uchr); + } + else { + debug_msg( "edit: ignore code: 0x%04x\n", c); + } + break; + } + } + + } + + // goto end + eb.pos = sbuf_len(eb.input); + + // refresh once more but without brace matching + bool bm = env->no_bracematch; + env->no_bracematch = true; + edit_refresh(env,&eb); + env->no_bracematch = bm; + + // save result + char* res; + if ((c == KEY_CTRL_D && sbuf_len(eb.input) == 0) || c == KEY_CTRL_C || c == KEY_EVENT_STOP) { + res = NULL; + } + else if (!tty_is_utf8(env->tty)) { + res = sbuf_strdup_from_utf8(eb.input); + } + else { + res = sbuf_strdup(eb.input); + } + + // update history + history_update(env->history, sbuf_string(eb.input)); + if (res == NULL || sbuf_len(eb.input) <= 1) { ic_history_remove_last(); } // no empty or single-char entries + history_save(env->history); + + // free resources + editstate_done(env->mem, &eb.undo); + editstate_done(env->mem, &eb.redo); + attrbuf_free(eb.attrs); + attrbuf_free(eb.attrs_extra); + sbuf_free(eb.input); + sbuf_free(eb.extra); + sbuf_free(eb.hint); + sbuf_free(eb.hint_help); + + return res; +} + diff --git a/extern/isocline/src/editline_completion.c b/extern/isocline/src/editline_completion.c new file mode 100644 index 000000000..1734ef345 --- /dev/null +++ b/extern/isocline/src/editline_completion.c @@ -0,0 +1,277 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Completion menu: this file is included in editline.c +//------------------------------------------------------------- + +// return true if anything changed +static bool edit_complete(ic_env_t* env, editor_t* eb, ssize_t idx) { + editor_start_modify(eb); + ssize_t newpos = completions_apply(env->completions, idx, eb->input, eb->pos); + if (newpos < 0) { + editor_undo_restore(eb,false); + return false; + } + eb->pos = newpos; + edit_refresh(env,eb); + return true; +} + +static bool edit_complete_longest_prefix(ic_env_t* env, editor_t* eb ) { + editor_start_modify(eb); + ssize_t newpos = completions_apply_longest_prefix( env->completions, eb->input, eb->pos ); + if (newpos < 0) { + editor_undo_restore(eb,false); + return false; + } + eb->pos = newpos; + edit_refresh(env,eb); + return true; +} + +ic_private void sbuf_append_tagged( stringbuf_t* sb, const char* tag, const char* content ) { + sbuf_appendf(sb, "[%s]", tag); + sbuf_append(sb,content); + sbuf_append(sb,"[/]"); +} + +static void editor_append_completion(ic_env_t* env, editor_t* eb, ssize_t idx, ssize_t width, bool numbered, bool selected ) { + const char* help = NULL; + const char* display = completions_get_display(env->completions, idx, &help); + if (display == NULL) return; + if (numbered) { + sbuf_appendf(eb->extra, "[ic-info]%s%zd [/]", (selected ? (tty_is_utf8(env->tty) ? "\xE2\x86\x92" : "*") : " "), 1 + idx); + width -= 3; + } + + if (width > 0) { + sbuf_appendf(eb->extra, "[width=\"%zd;left; ;on\"]", width ); + } + if (selected) { + sbuf_append(eb->extra, "[ic-emphasis]"); + } + sbuf_append(eb->extra, display); + if (selected) { sbuf_append(eb->extra,"[/ic-emphasis]"); } + if (help != NULL) { + sbuf_append(eb->extra, " "); + sbuf_append_tagged(eb->extra, "ic-info", help ); + } + if (width > 0) { sbuf_append(eb->extra,"[/width]"); } +} + +// 2 and 3 column output up to 80 wide +#define IC_DISPLAY2_MAX 34 +#define IC_DISPLAY2_COL (3+IC_DISPLAY2_MAX) +#define IC_DISPLAY2_WIDTH (2*IC_DISPLAY2_COL + 2) // 75 + +#define IC_DISPLAY3_MAX 21 +#define IC_DISPLAY3_COL (3+IC_DISPLAY3_MAX) +#define IC_DISPLAY3_WIDTH (3*IC_DISPLAY3_COL + 2*2) // 76 + +static void editor_append_completion2(ic_env_t* env, editor_t* eb, ssize_t col_width, ssize_t idx1, ssize_t idx2, ssize_t selected ) { + editor_append_completion(env, eb, idx1, col_width, true, (idx1 == selected) ); + sbuf_append( eb->extra, " "); + editor_append_completion(env, eb, idx2, col_width, true, (idx2 == selected) ); +} + +static void editor_append_completion3(ic_env_t* env, editor_t* eb, ssize_t col_width, ssize_t idx1, ssize_t idx2, ssize_t idx3, ssize_t selected ) { + editor_append_completion(env, eb, idx1, col_width, true, (idx1 == selected) ); + sbuf_append( eb->extra, " "); + editor_append_completion(env, eb, idx2, col_width, true, (idx2 == selected)); + sbuf_append( eb->extra, " "); + editor_append_completion(env, eb, idx3, col_width, true, (idx3 == selected) ); +} + +static ssize_t edit_completions_max_width( ic_env_t* env, ssize_t count ) { + ssize_t max_width = 0; + for( ssize_t i = 0; i < count; i++) { + const char* help = NULL; + ssize_t w = bbcode_column_width(env->bbcode, completions_get_display(env->completions, i, &help)); + if (help != NULL) { + w += 2 + bbcode_column_width(env->bbcode, help); + } + if (w > max_width) { + max_width = w; + } + } + return max_width; +} + +static void edit_completion_menu(ic_env_t* env, editor_t* eb, bool more_available) { + ssize_t count = completions_count(env->completions); + ssize_t count_displayed = count; + assert(count > 1); + ssize_t selected = (env->complete_nopreview ? 0 : -1); // select first or none + ssize_t percolumn = count; + +again: + // show first 9 (or 8) completions + sbuf_clear(eb->extra); + ssize_t twidth = term_get_width(env->term) - 1; + ssize_t colwidth; + if (count > 3 && ((colwidth = 3 + edit_completions_max_width(env, 9))*3 + 2*2) < twidth) { + // display as a 3 column block + count_displayed = (count > 9 ? 9 : count); + percolumn = 3; + for (ssize_t rw = 0; rw < percolumn; rw++) { + if (rw > 0) sbuf_append(eb->extra, "\n"); + editor_append_completion3(env, eb, colwidth, rw, percolumn+rw, (2*percolumn)+rw, selected); + } + } + else if (count > 4 && ((colwidth = 3 + edit_completions_max_width(env, 8))*2 + 2) < twidth) { + // display as a 2 column block if some entries are too wide for three columns + count_displayed = (count > 8 ? 8 : count); + percolumn = (count_displayed <= 6 ? 3 : 4); + for (ssize_t rw = 0; rw < percolumn; rw++) { + if (rw > 0) sbuf_append(eb->extra, "\n"); + editor_append_completion2(env, eb, colwidth, rw, percolumn+rw, selected); + } + } + else { + // display as a list + count_displayed = (count > 9 ? 9 : count); + percolumn = count_displayed; + for (ssize_t i = 0; i < count_displayed; i++) { + if (i > 0) sbuf_append(eb->extra, "\n"); + editor_append_completion(env, eb, i, -1, true /* numbered */, selected == i); + } + } + if (count > count_displayed) { + if (more_available) { + sbuf_append(eb->extra, "\n[ic-info](press page-down (or ctrl-j) to see all further completions)[/]"); + } + else { + sbuf_appendf(eb->extra, "\n[ic-info](press page-down (or ctrl-j) to see all %zd completions)[/]", count ); + } + } + if (!env->complete_nopreview && selected >= 0 && selected <= count_displayed) { + edit_complete(env,eb,selected); + editor_undo_restore(eb,false); + } + else { + edit_refresh(env, eb); + } + + // read here; if not a valid key, push it back and return to main event loop + code_t c = tty_read(env->tty); + if (tty_term_resize_event(env->tty)) { + edit_resize(env, eb); + } + sbuf_clear(eb->extra); + + // direct selection? + if (c >= '1' && c <= '9') { + ssize_t i = (c - '1'); + if (i < count) { + selected = i; + c = KEY_ENTER; + } + } + + // process commands + if (c == KEY_DOWN || c == KEY_TAB) { + selected++; + if (selected >= count_displayed) { + //term_beep(env->term); + selected = 0; + } + goto again; + } + else if (c == KEY_UP || c == KEY_SHIFT_TAB) { + selected--; + if (selected < 0) { + selected = count_displayed - 1; + //term_beep(env->term); + } + goto again; + } + else if (c == KEY_F1) { + edit_show_help(env, eb); + goto again; + } + else if (c == KEY_ESC) { + completions_clear(env->completions); + edit_refresh(env,eb); + c = 0; // ignore and return + } + else if (selected >= 0 && (c == KEY_ENTER || c == KEY_RIGHT || c == KEY_END)) /* || c == KEY_TAB*/ { + // select the current entry + assert(selected < count); + c = 0; + edit_complete(env, eb, selected); + if (env->complete_autotab) { + tty_code_pushback(env->tty,KEY_EVENT_AUTOTAB); // immediately try to complete again + } + } + else if (!env->complete_nopreview && !code_is_virt_key(c)) { + // if in preview mode, select the current entry and exit the menu + assert(selected < count); + edit_complete(env, eb, selected); + } + else if ((c == KEY_PAGEDOWN || c == KEY_LINEFEED) && count > 9) { + // show all completions + c = 0; + if (more_available) { + // generate all entries (up to the max (= 1000)) + count = completions_generate(env, env->completions, sbuf_string(eb->input), eb->pos, IC_MAX_COMPLETIONS_TO_SHOW); + } + rowcol_t rc; + edit_get_rowcol(env,eb,&rc); + edit_clear(env,eb); + edit_write_prompt(env,eb,0,false); + term_writeln(env->term, ""); + for(ssize_t i = 0; i < count; i++) { + const char* display = completions_get_display(env->completions, i, NULL); + if (display != NULL) { + bbcode_println(env->bbcode, display); + } + } + if (count >= IC_MAX_COMPLETIONS_TO_SHOW) { + bbcode_println(env->bbcode, "[ic-info]... and more.[/]"); + } + else { + bbcode_printf(env->bbcode, "[ic-info](%zd possible completions)[/]\n", count ); + } + for(ssize_t i = 0; i < rc.row+1; i++) { + term_write(env->term, " \n"); + } + eb->cur_rows = 0; + edit_refresh(env,eb); + } + else { + edit_refresh(env,eb); + } + // done + completions_clear(env->completions); + if (c != 0) tty_code_pushback(env->tty,c); +} + +static void edit_generate_completions(ic_env_t* env, editor_t* eb, bool autotab) { + debug_msg( "edit: complete: %zd: %s\n", eb->pos, sbuf_string(eb->input) ); + if (eb->pos < 0) return; + ssize_t count = completions_generate(env, env->completions, sbuf_string(eb->input), eb->pos, IC_MAX_COMPLETIONS_TO_TRY); + bool more_available = (count >= IC_MAX_COMPLETIONS_TO_TRY); + if (count <= 0) { + // no completions + if (!autotab) { term_beep(env->term); } + } + else if (count == 1) { + // complete if only one match + if (edit_complete(env,eb,0 /*idx*/) && env->complete_autotab) { + tty_code_pushback(env->tty,KEY_EVENT_AUTOTAB); + } + } + else { + //term_beep(env->term); + if (!more_available) { + edit_complete_longest_prefix(env,eb); + } + completions_sort(env->completions); + edit_completion_menu( env, eb, more_available); + } +} diff --git a/extern/isocline/src/editline_help.c b/extern/isocline/src/editline_help.c new file mode 100644 index 000000000..fa07d1db6 --- /dev/null +++ b/extern/isocline/src/editline_help.c @@ -0,0 +1,140 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Help: this is included into editline.c +//------------------------------------------------------------- + +static const char* help[] = { + "","Navigation:", + "left," + "^b", "go one character to the left", + "right," + "^f", "go one character to the right", + "up", "go one row up, or back in the history", + "down", "go one row down, or forward in the history", + #ifdef __APPLE__ + "shift-left", + #else + "^left", + #endif + "go to the start of the previous word", + #ifdef __APPLE__ + "shift-right", + #else + "^right", + #endif + "go to the end the current word", + "home," + "^a", "go to the start of the current line", + "end," + "^e", "go to the end of the current line", + "pgup," + "^home", "go to the start of the current input", + "pgdn," + "^end", "go to the end of the current input", + "alt-m", "jump to matching brace", + "^p", "go back in the history", + "^n", "go forward in the history", + "^r,^s", "search the history starting with the current word", + "","", + + "", "Deletion:", + "del,^d", "delete the current character", + "backsp,^h", "delete the previous character", + "^w", "delete to preceding white space", + "alt-backsp", "delete to the start of the current word", + "alt-d", "delete to the end of the current word", + "^u", "delete to the start of the current line", + "^k", "delete to the end of the current line", + "esc", "delete the current input, or done with empty input", + "","", + + "", "Editing:", + "enter", "accept current input", + #ifndef __APPLE__ + "^enter, ^j", "", + "shift-tab", + #else + "shift-tab,^j", + #endif + "create a new line for multi-line input", + //" ", "(or type '\\' followed by enter)", + "^l", "clear screen", + "^t", "swap with previous character (move character backward)", + "^z,^_", "undo", + "^y", "redo", + //"^C", "done with empty input", + //"F1", "show this help", + "tab", "try to complete the current input", + "","", + "","In the completion menu:", + "enter,left", "use the currently selected completion", + "1 - 9", "use completion N from the menu", + "tab,down", "select the next completion", + "shift-tab,up","select the previous completion", + "esc", "exit menu without completing", + "pgdn,^j", "show all further possible completions", + "","", + "","In incremental history search:", + "enter", "use the currently found history entry", + "backsp," + "^z", "go back to the previous match (undo)", + "tab," + "^r", "find the next match", + "shift-tab," + "^s", "find an earlier match", + "esc", "exit search", + " ","", + NULL, NULL +}; + +static const char* help_initial = + "[ic-info]" + "Isocline v1.0, copyright (c) 2021 Daan Leijen.\n" + "This is free software; you can redistribute it and/or\n" + "modify it under the terms of the MIT License.\n" + "See <[url]https://github.com/daanx/isocline[/url]> for further information.\n" + "We use ^ as a shorthand for ctrl-.\n" + "\n" + "Overview:\n" + "\n[ansi-lightgray]" + " home,ctrl-a cursor end,ctrl-e\n" + " ┌────────────────┼───────────────┐ (navigate)\n" + //" │ │ │\n" + #ifndef __APPLE__ + " │ ctrl-left │ ctrl-right │\n" + #else + " │ alt-left │ alt-right │\n" + #endif + " │ ┌───────┼──────┐ │ ctrl-r : search history\n" + " ▼ ▼ ▼ ▼ ▼ tab : complete word\n" + " prompt> [ansi-darkgray]it's the quintessential language[/] shift-tab: insert new line\n" + " ▲ ▲ ▲ ▲ esc : delete input, done\n" + " │ └──────────────┘ │ ctrl-z : undo\n" + " │ alt-backsp alt-d │\n" + //" │ │ │\n" + " └────────────────────────────────┘ (delete)\n" + " ctrl-u ctrl-k\n" + "[/ansi-lightgray][/ic-info]\n"; + +static void edit_show_help(ic_env_t* env, editor_t* eb) { + edit_clear(env, eb); + bbcode_println(env->bbcode, help_initial); + for (ssize_t i = 0; help[i] != NULL && help[i+1] != NULL; i += 2) { + if (help[i][0] == 0) { + bbcode_printf(env->bbcode, "[ic-info]%s[/]\n", help[i+1]); + } + else { + bbcode_printf(env->bbcode, " [ic-emphasis]%-13s[/][ansi-lightgray]%s%s[/]\n", help[i], (help[i+1][0] == 0 ? "" : ": "), help[i+1]); + } + } + + eb->cur_rows = 0; + eb->cur_row = 0; + edit_refresh(env, eb); +} diff --git a/extern/isocline/src/editline_history.c b/extern/isocline/src/editline_history.c new file mode 100644 index 000000000..2a0afa1c7 --- /dev/null +++ b/extern/isocline/src/editline_history.c @@ -0,0 +1,260 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// History search: this file is included in editline.c +//------------------------------------------------------------- + +static void edit_history_at(ic_env_t* env, editor_t* eb, int ofs ) +{ + if (eb->modified) { + history_update(env->history, sbuf_string(eb->input)); // update first entry if modified + eb->history_idx = 0; // and start again + eb->modified = false; + } + const char* entry = history_get(env->history,eb->history_idx + ofs); + // debug_msg( "edit: history: at: %d + %d, found: %s\n", eb->history_idx, ofs, entry); + if (entry == NULL) { + term_beep(env->term); + } + else { + eb->history_idx += ofs; + sbuf_replace(eb->input, entry); + if (ofs > 0) { + // at end of first line when scrolling up + ssize_t end = sbuf_find_line_end(eb->input,0); + eb->pos = (end < 0 ? 0 : end); + } + else { + eb->pos = sbuf_len(eb->input); // at end of last line when scrolling down + } + edit_refresh(env, eb); + } +} + +static void edit_history_prev(ic_env_t* env, editor_t* eb) { + edit_history_at(env,eb, 1 ); +} + +static void edit_history_next(ic_env_t* env, editor_t* eb) { + edit_history_at(env,eb, -1 ); +} + +typedef struct hsearch_s { + struct hsearch_s* next; + ssize_t hidx; + ssize_t match_pos; + ssize_t match_len; + bool cinsert; +} hsearch_t; + +static void hsearch_push( alloc_t* mem, hsearch_t** hs, ssize_t hidx, ssize_t mpos, ssize_t mlen, bool cinsert ) { + hsearch_t* h = mem_zalloc_tp( mem, hsearch_t ); + if (h == NULL) return; + h->hidx = hidx; + h->match_pos = mpos; + h->match_len = mlen; + h->cinsert = cinsert; + h->next = *hs; + *hs = h; +} + +static bool hsearch_pop( alloc_t* mem, hsearch_t** hs, ssize_t* hidx, ssize_t* match_pos, ssize_t* match_len, bool* cinsert ) { + hsearch_t* h = *hs; + if (h == NULL) return false; + *hs = h->next; + if (hidx != NULL) *hidx = h->hidx; + if (match_pos != NULL) *match_pos = h->match_pos; + if (match_len != NULL) *match_len = h->match_len; + if (cinsert != NULL) *cinsert = h->cinsert; + mem_free(mem, h); + return true; +} + +static void hsearch_done( alloc_t* mem, hsearch_t* hs ) { + while (hs != NULL) { + hsearch_t* next = hs->next; + mem_free(mem, hs); + hs = next; + } +} + +static void edit_history_search(ic_env_t* env, editor_t* eb, char* initial ) { + if (history_count( env->history ) <= 0) { + term_beep(env->term); + return; + } + + // update history + if (eb->modified) { + history_update(env->history, sbuf_string(eb->input)); // update first entry if modified + eb->history_idx = 0; // and start again + eb->modified = false; + } + + // set a search prompt and remember the previous state + editor_undo_capture(eb); + eb->disable_undo = true; + bool old_hint = ic_enable_hint(false); + const char* prompt_text = eb->prompt_text; + eb->prompt_text = "history search"; + + // search state + hsearch_t* hs = NULL; // search undo + ssize_t hidx = 1; // current history entry + ssize_t match_pos = 0; // current matched position + ssize_t match_len = 0; // length of the match + const char* hentry = NULL; // current history entry + + // Simulate per character searches for each letter in `initial` (so backspace works) + if (initial != NULL) { + const ssize_t initial_len = ic_strlen(initial); + ssize_t ipos = 0; + while( ipos < initial_len ) { + ssize_t next = str_next_ofs( initial, initial_len, ipos, NULL ); + if (next < 0) break; + hsearch_push( eb->mem, &hs, hidx, match_pos, match_len, true); + char c = initial[ipos + next]; // terminate temporarily + initial[ipos + next] = 0; + if (history_search( env->history, hidx, initial, true, &hidx, &match_pos )) { + match_len = ipos + next; + } + else if (ipos + next >= initial_len) { + term_beep(env->term); + } + initial[ipos + next] = c; // restore + ipos += next; + } + sbuf_replace( eb->input, initial); + eb->pos = ipos; + } + else { + sbuf_clear( eb->input ); + eb->pos = 0; + } + + // Incremental search +again: + hentry = history_get(env->history,hidx); + if (hentry != NULL) { + sbuf_appendf(eb->extra, "[ic-info]%zd. [/][ic-diminish][!pre]", hidx); + sbuf_append_n( eb->extra, hentry, match_pos ); + sbuf_append(eb->extra, "[/pre][u ic-emphasis][!pre]" ); + sbuf_append_n( eb->extra, hentry + match_pos, match_len ); + sbuf_append(eb->extra, "[/pre][/u][!pre]" ); + sbuf_append(eb->extra, hentry + match_pos + match_len ); + sbuf_append(eb->extra, "[/pre][/ic-diminish]"); + if (!env->no_help) { + sbuf_append(eb->extra, "\n[ic-info](use tab for the next match)[/]"); + } + sbuf_append(eb->extra, "\n" ); + } + edit_refresh(env, eb); + + // Wait for input + code_t c = (hentry == NULL ? KEY_ESC : tty_read(env->tty)); + if (tty_term_resize_event(env->tty)) { + edit_resize(env, eb); + } + sbuf_clear(eb->extra); + + // Process commands + if (c == KEY_ESC || c == KEY_BELL /* ^G */ || c == KEY_CTRL_C) { + c = 0; + eb->disable_undo = false; + editor_undo_restore(eb, false); + } + else if (c == KEY_ENTER) { + c = 0; + editor_undo_forget(eb); + sbuf_replace( eb->input, hentry ); + eb->pos = sbuf_len(eb->input); + eb->modified = false; + eb->history_idx = hidx; + } + else if (c == KEY_BACKSP || c == KEY_CTRL_Z) { + // undo last search action + bool cinsert; + if (hsearch_pop(env->mem,&hs, &hidx, &match_pos, &match_len, &cinsert)) { + if (cinsert) edit_backspace(env,eb); + } + goto again; + } + else if (c == KEY_CTRL_R || c == KEY_TAB || c == KEY_UP) { + // search backward + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, false); + if (!history_search( env->history, hidx+1, sbuf_string(eb->input), true, &hidx, &match_pos )) { + hsearch_pop(env->mem,&hs,NULL,NULL,NULL,NULL); + term_beep(env->term); + }; + goto again; + } + else if (c == KEY_CTRL_S || c == KEY_SHIFT_TAB || c == KEY_DOWN) { + // search forward + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, false); + if (!history_search( env->history, hidx-1, sbuf_string(eb->input), false, &hidx, &match_pos )) { + hsearch_pop(env->mem, &hs,NULL,NULL,NULL,NULL); + term_beep(env->term); + }; + goto again; + } + else if (c == KEY_F1) { + edit_show_help(env, eb); + goto again; + } + else { + // insert character and search further backward + char chr; + unicode_t uchr; + if (code_is_ascii_char(c,&chr)) { + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, true); + edit_insert_char(env,eb,chr); + } + else if (code_is_unicode(c,&uchr)) { + hsearch_push(env->mem, &hs, hidx, match_pos, match_len, true); + edit_insert_unicode(env,eb,uchr); + } + else { + // ignore command + term_beep(env->term); + goto again; + } + // search for the new input + if (history_search( env->history, hidx, sbuf_string(eb->input), true, &hidx, &match_pos )) { + match_len = sbuf_len(eb->input); + } + else { + term_beep(env->term); + }; + goto again; + } + + // done + eb->disable_undo = false; + hsearch_done(env->mem,hs); + eb->prompt_text = prompt_text; + ic_enable_hint(old_hint); + edit_refresh(env,eb); + if (c != 0) tty_code_pushback(env->tty, c); +} + +// Start an incremental search with the current word +static void edit_history_search_with_current_word(ic_env_t* env, editor_t* eb) { + char* initial = NULL; + ssize_t start = sbuf_find_word_start( eb->input, eb->pos ); + if (start >= 0) { + const ssize_t next = sbuf_next(eb->input, start, NULL); + if (!ic_char_is_idletter(sbuf_string(eb->input) + start, (long)(next - start))) { + start = next; + } + if (start >= 0 && start < eb->pos) { + initial = mem_strndup(eb->mem, sbuf_string(eb->input) + start, eb->pos - start); + } + } + edit_history_search( env, eb, initial); + mem_free(env->mem, initial); +} diff --git a/extern/isocline/src/env.h b/extern/isocline/src/env.h new file mode 100644 index 000000000..edfc10033 --- /dev/null +++ b/extern/isocline/src/env.h @@ -0,0 +1,60 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_ENV_H +#define IC_ENV_H + +#include "../include/isocline.h" +#include "common.h" +#include "term.h" +#include "tty.h" +#include "stringbuf.h" +#include "history.h" +#include "completions.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// Environment +//------------------------------------------------------------- + +struct ic_env_s { + alloc_t* mem; // potential custom allocator + ic_env_t* next; // next environment (used for proper deallocation) + term_t* term; // terminal + tty_t* tty; // keyboard (NULL if stdin is a pipe, file, etc) + completions_t* completions; // current completions + history_t* history; // edit history + bbcode_t* bbcode; // print with bbcodes + const char* prompt_marker; // the prompt marker (defaults to "> ") + const char* cprompt_marker; // prompt marker for continuation lines (defaults to `prompt_marker`) + ic_highlight_fun_t* highlighter; // highlight callback + void* highlighter_arg; // user state for the highlighter. + const char* match_braces; // matching braces, e.g "()[]{}" + const char* auto_braces; // auto insertion braces, e.g "()[]{}\"\"''" + char multiline_eol; // character used for multiline input ("\") (set to 0 to disable) + bool initialized; // are we initialized? + bool noedit; // is rich editing possible (tty != NULL) + bool singleline_only; // allow only single line editing? + bool complete_nopreview; // do not show completion preview for each selection in the completion menu? + bool complete_autotab; // try to keep completing after a completion? + bool no_multiline_indent; // indent continuation lines to line up under the initial prompt + bool no_help; // show short help line for history search etc. + bool no_hint; // allow hinting? + bool no_highlight; // enable highlighting? + bool no_bracematch; // enable brace matching? + bool no_autobrace; // enable automatic brace insertion? + bool no_lscolors; // use LSCOLORS/LS_COLORS to colorize file name completions? + long hint_delay; // delay before displaying a hint in milliseconds +}; + +ic_private char* ic_editline(ic_env_t* env, const char* prompt_text); + +ic_private ic_env_t* ic_get_env(void); +ic_private const char* ic_env_get_auto_braces(ic_env_t* env); +ic_private const char* ic_env_get_match_braces(ic_env_t* env); + +#endif // IC_ENV_H diff --git a/extern/isocline/src/highlight.c b/extern/isocline/src/highlight.c new file mode 100644 index 000000000..59c7255c6 --- /dev/null +++ b/extern/isocline/src/highlight.c @@ -0,0 +1,259 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +#include +#include "common.h" +#include "term.h" +#include "stringbuf.h" +#include "attr.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// Syntax highlighting +//------------------------------------------------------------- + +struct ic_highlight_env_s { + attrbuf_t* attrs; + const char* input; + ssize_t input_len; + bbcode_t* bbcode; + alloc_t* mem; + ssize_t cached_upos; // cached unicode position + ssize_t cached_cpos; // corresponding utf-8 byte position +}; + + +ic_private void highlight( alloc_t* mem, bbcode_t* bb, const char* s, attrbuf_t* attrs, ic_highlight_fun_t* highlighter, void* arg ) { + const ssize_t len = ic_strlen(s); + if (len <= 0) return; + attrbuf_set_at(attrs,0,len,attr_none()); // fill to length of s + if (highlighter != NULL) { + ic_highlight_env_t henv; + henv.attrs = attrs; + henv.input = s; + henv.input_len = len; + henv.bbcode = bb; + henv.mem = mem; + henv.cached_cpos = 0; + henv.cached_upos = 0; + (*highlighter)( &henv, s, arg ); + } +} + + +//------------------------------------------------------------- +// Client interface +//------------------------------------------------------------- + +static void pos_adjust( ic_highlight_env_t* henv, ssize_t* ppos, ssize_t* plen ) { + ssize_t pos = *ppos; + ssize_t len = *plen; + if (pos >= henv->input_len) return; + if (pos >= 0 && len >= 0) return; // already character positions + if (henv->input == NULL) return; + + if (pos < 0) { + // negative `pos` is used as the unicode character position (for easy interfacing with Haskell) + ssize_t upos = -pos; + ssize_t cpos = 0; + ssize_t ucount = 0; + if (henv->cached_upos <= upos) { // if we have a cached position, start from there + ucount = henv->cached_upos; + cpos = henv->cached_cpos; + } + while ( ucount < upos ) { + ssize_t next = str_next_ofs(henv->input, henv->input_len, cpos, NULL); + if (next <= 0) return; + ucount++; + cpos += next; + } + *ppos = pos = cpos; + // and cache it to avoid quadratic behavior + henv->cached_upos = upos; + henv->cached_cpos = cpos; + } + if (len < 0) { + // negative `len` is used as a unicode character length + len = -len; + ssize_t ucount = 0; + ssize_t clen = 0; + while (ucount < len) { + ssize_t next = str_next_ofs(henv->input, henv->input_len, pos + clen, NULL); + if (next <= 0) return; + ucount++; + clen += next; + } + *plen = len = clen; + // and update cache if possible + if (henv->cached_cpos == pos) { + henv->cached_upos += ucount; + henv->cached_cpos += clen; + } + } +} + +static void highlight_attr(ic_highlight_env_t* henv, ssize_t pos, ssize_t count, attr_t attr ) { + if (henv==NULL) return; + pos_adjust(henv,&pos,&count); + if (pos < 0 || count <= 0) return; + attrbuf_update_at(henv->attrs, pos, count, attr); +} + +ic_public void ic_highlight(ic_highlight_env_t* henv, long pos, long count, const char* style ) { + if (henv == NULL || style==NULL || style[0]==0 || pos < 0) return; + highlight_attr(henv,pos,count,bbcode_style( henv->bbcode, style )); +} + +ic_public void ic_highlight_formatted(ic_highlight_env_t* henv, const char* s, const char* fmt) { + if (s==NULL || s[0] == 0 || fmt==NULL) return; + attrbuf_t* attrs = attrbuf_new(henv->mem); + stringbuf_t* out = sbuf_new(henv->mem); // todo: avoid allocating out? + if (attrs!=NULL && out != NULL) { + bbcode_append( henv->bbcode, fmt, out, attrs); + const ssize_t len = ic_strlen(s); + if (sbuf_len(out) != len) { + debug_msg("highlight: formatted string content differs from the original input:\n original: %s\n formatted: %s\n", s, fmt); + } + for( ssize_t i = 0; i < len; i++) { + attrbuf_update_at(henv->attrs, i, 1, attrbuf_attr_at(attrs,i)); + } + } + sbuf_free(out); + attrbuf_free(attrs); +} + +//------------------------------------------------------------- +// Brace matching +//------------------------------------------------------------- +#define MAX_NESTING (64) + +typedef struct brace_s { + char close; + bool at_cursor; + ssize_t pos; +} brace_t; + +ic_private void highlight_match_braces(const char* s, attrbuf_t* attrs, ssize_t cursor_pos, const char* braces, attr_t match_attr, attr_t error_attr) +{ + brace_t open[MAX_NESTING+1]; + ssize_t nesting = 0; + const ssize_t brace_len = ic_strlen(braces); + for (long i = 0; i < ic_strlen(s); i++) { + const char c = s[i]; + // push open brace + bool found_open = false; + for (ssize_t b = 0; b < brace_len; b += 2) { + if (c == braces[b]) { + // open brace + if (nesting >= MAX_NESTING) return; // give up + open[nesting].close = braces[b+1]; + open[nesting].pos = i; + open[nesting].at_cursor = (i == cursor_pos - 1); + nesting++; + found_open = true; + break; + } + } + if (found_open) continue; + + // pop to closing brace and potentially highlight + for (ssize_t b = 1; b < brace_len; b += 2) { + if (c == braces[b]) { + // close brace + if (nesting <= 0) { + // unmatched close brace + attrbuf_update_at( attrs, i, 1, error_attr); + } + else { + // can we fix an unmatched brace where we can match by popping just one? + if (open[nesting-1].close != c && nesting > 1 && open[nesting-2].close == c) { + // assume previous open brace was wrong + attrbuf_update_at(attrs, open[nesting-1].pos, 1, error_attr); + nesting--; + } + if (open[nesting-1].close != c) { + // unmatched open brace + attrbuf_update_at( attrs, i, 1, error_attr); + } + else { + // matching brace + nesting--; + if (i == cursor_pos - 1 || (open[nesting].at_cursor && open[nesting].pos != i - 1)) { + // highlight matching brace + attrbuf_update_at(attrs, open[nesting].pos, 1, match_attr); + attrbuf_update_at(attrs, i, 1, match_attr); + } + } + } + break; + } + } + } + // note: don't mark further unmatched open braces as in error +} + + +ic_private ssize_t find_matching_brace(const char* s, ssize_t cursor_pos, const char* braces, bool* is_balanced) +{ + if (is_balanced != NULL) { *is_balanced = false; } + bool balanced = true; + ssize_t match = -1; + brace_t open[MAX_NESTING+1]; + ssize_t nesting = 0; + const ssize_t brace_len = ic_strlen(braces); + for (long i = 0; i < ic_strlen(s); i++) { + const char c = s[i]; + // push open brace + bool found_open = false; + for (ssize_t b = 0; b < brace_len; b += 2) { + if (c == braces[b]) { + // open brace + if (nesting >= MAX_NESTING) return -1; // give up + open[nesting].close = braces[b+1]; + open[nesting].pos = i; + open[nesting].at_cursor = (i == cursor_pos - 1); + nesting++; + found_open = true; + break; + } + } + if (found_open) continue; + + // pop to closing brace + for (ssize_t b = 1; b < brace_len; b += 2) { + if (c == braces[b]) { + // close brace + if (nesting <= 0) { + // unmatched close brace + balanced = false; + } + else { + if (open[nesting-1].close != c) { + // unmatched open brace + balanced = false; + } + else { + // matching brace + nesting--; + if (i == cursor_pos - 1) { + // found matching open brace + match = open[nesting].pos + 1; + } + else if (open[nesting].at_cursor) { + // found matching close brace + match = i + 1; + } + } + } + break; + } + } + } + if (nesting != 0) { balanced = false; } + if (is_balanced != NULL) { *is_balanced = balanced; } + return match; +} diff --git a/extern/isocline/src/highlight.h b/extern/isocline/src/highlight.h new file mode 100644 index 000000000..67da02ffd --- /dev/null +++ b/extern/isocline/src/highlight.h @@ -0,0 +1,24 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_HIGHLIGHT_H +#define IC_HIGHLIGHT_H + +#include "common.h" +#include "attr.h" +#include "term.h" +#include "bbcode.h" + +//------------------------------------------------------------- +// Syntax highlighting +//------------------------------------------------------------- + +ic_private void highlight( alloc_t* mem, bbcode_t* bb, const char* s, attrbuf_t* attrs, ic_highlight_fun_t* highlighter, void* arg ); +ic_private void highlight_match_braces(const char* s, attrbuf_t* attrs, ssize_t cursor_pos, const char* braces, attr_t match_attr, attr_t error_attr); +ic_private ssize_t find_matching_brace(const char* s, ssize_t cursor_pos, const char* braces, bool* is_balanced); + +#endif // IC_HIGHLIGHT_H diff --git a/extern/isocline/src/history.c b/extern/isocline/src/history.c new file mode 100644 index 000000000..440976aa5 --- /dev/null +++ b/extern/isocline/src/history.c @@ -0,0 +1,269 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "history.h" +#include "stringbuf.h" + +#define IC_MAX_HISTORY (200) + +struct history_s { + ssize_t count; // current number of entries in use + ssize_t len; // size of elems + const char** elems; // history items (up to count) + const char* fname; // history file + alloc_t* mem; + bool allow_duplicates; // allow duplicate entries? +}; + +ic_private history_t* history_new(alloc_t* mem) { + history_t* h = mem_zalloc_tp(mem,history_t); + h->mem = mem; + return h; +} + +ic_private void history_free(history_t* h) { + if (h == NULL) return; + history_clear(h); + if (h->len > 0) { + mem_free( h->mem, h->elems ); + h->elems = NULL; + h->len = 0; + } + mem_free(h->mem, h->fname); + h->fname = NULL; + mem_free(h->mem, h); // free ourselves +} + +ic_private bool history_enable_duplicates( history_t* h, bool enable ) { + bool prev = h->allow_duplicates; + h->allow_duplicates = enable; + return prev; +} + +ic_private ssize_t history_count(const history_t* h) { + return h->count; +} + +//------------------------------------------------------------- +// push/clear +//------------------------------------------------------------- + +ic_private bool history_update( history_t* h, const char* entry ) { + if (entry==NULL) return false; + history_remove_last(h); + history_push(h,entry); + //debug_msg("history: update: with %s; now at %s\n", entry, history_get(h,0)); + return true; +} + +static void history_delete_at( history_t* h, ssize_t idx ) { + if (idx < 0 || idx >= h->count) return; + mem_free(h->mem, h->elems[idx]); + for(ssize_t i = idx+1; i < h->count; i++) { + h->elems[i-1] = h->elems[i]; + } + h->count--; +} + +ic_private bool history_push( history_t* h, const char* entry ) { + if (h->len <= 0 || entry==NULL) return false; + // remove any older duplicate + if (!h->allow_duplicates) { + for( int i = 0; i < h->count; i++) { + if (strcmp(h->elems[i],entry) == 0) { + history_delete_at(h,i); + } + } + } + // insert at front + if (h->count == h->len) { + // delete oldest entry + history_delete_at(h,0); + } + assert(h->count < h->len); + h->elems[h->count] = mem_strdup(h->mem,entry); + h->count++; + return true; +} + + +static void history_remove_last_n( history_t* h, ssize_t n ) { + if (n <= 0) return; + if (n > h->count) n = h->count; + for( ssize_t i = h->count - n; i < h->count; i++) { + mem_free( h->mem, h->elems[i] ); + } + h->count -= n; + assert(h->count >= 0); +} + +ic_private void history_remove_last(history_t* h) { + history_remove_last_n(h,1); +} + +ic_private void history_clear(history_t* h) { + history_remove_last_n( h, h->count ); +} + +ic_private const char* history_get( const history_t* h, ssize_t n ) { + if (n < 0 || n >= h->count) return NULL; + return h->elems[h->count - n - 1]; +} + +ic_private bool history_search( const history_t* h, ssize_t from /*including*/, const char* search, bool backward, ssize_t* hidx, ssize_t* hpos ) { + const char* p = NULL; + ssize_t i; + if (backward) { + for( i = from; i < h->count; i++ ) { + p = strstr( history_get(h,i), search); + if (p != NULL) break; + } + } + else { + for( i = from; i >= 0; i-- ) { + p = strstr( history_get(h,i), search); + if (p != NULL) break; + } + } + if (p == NULL) return false; + if (hidx != NULL) *hidx = i; + if (hpos != NULL) *hpos = (p - history_get(h,i)); + return true; +} + +//------------------------------------------------------------- +// +//------------------------------------------------------------- + +ic_private void history_load_from(history_t* h, const char* fname, long max_entries ) { + history_clear(h); + h->fname = mem_strdup(h->mem,fname); + if (max_entries == 0) { + assert(h->elems == NULL); + return; + } + if (max_entries < 0 || max_entries > IC_MAX_HISTORY) max_entries = IC_MAX_HISTORY; + h->elems = (const char**)mem_zalloc_tp_n(h->mem, char*, max_entries ); + if (h->elems == NULL) return; + h->len = max_entries; + history_load(h); +} + + + + +//------------------------------------------------------------- +// save/load history to file +//------------------------------------------------------------- + +static char from_xdigit( int c ) { + if (c >= '0' && c <= '9') return (char)(c - '0'); + if (c >= 'A' && c <= 'F') return (char)(10 + (c - 'A')); + if (c >= 'a' && c <= 'f') return (char)(10 + (c - 'a')); + return 0; +} + +static char to_xdigit( uint8_t c ) { + if (c <= 9) return ((char)c + '0'); + if (c >= 10 && c <= 15) return ((char)c - 10 + 'A'); + return '0'; +} + +static bool ic_isxdigit( int c ) { + return ((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') || (c >= '0' && c <= '9')); +} + +static bool history_read_entry( history_t* h, FILE* f, stringbuf_t* sbuf ) { + sbuf_clear(sbuf); + while( !feof(f)) { + int c = fgetc(f); + if (c == EOF || c == '\n') break; + if (c == '\\') { + c = fgetc(f); + if (c == 'n') { sbuf_append(sbuf,"\n"); } + else if (c == 'r') { /* ignore */ } // sbuf_append(sbuf,"\r"); + else if (c == 't') { sbuf_append(sbuf,"\t"); } + else if (c == '\\') { sbuf_append(sbuf,"\\"); } + else if (c == 'x') { + int c1 = fgetc(f); + int c2 = fgetc(f); + if (ic_isxdigit(c1) && ic_isxdigit(c2)) { + char chr = from_xdigit(c1)*16 + from_xdigit(c2); + sbuf_append_char(sbuf,chr); + } + else return false; + } + else return false; + } + else sbuf_append_char(sbuf,(char)c); + } + if (sbuf_len(sbuf)==0 || sbuf_string(sbuf)[0] == '#') return true; + return history_push(h, sbuf_string(sbuf)); +} + +static bool history_write_entry( const char* entry, FILE* f, stringbuf_t* sbuf ) { + sbuf_clear(sbuf); + //debug_msg("history: write: %s\n", entry); + while( entry != NULL && *entry != 0 ) { + char c = *entry++; + if (c == '\\') { sbuf_append(sbuf,"\\\\"); } + else if (c == '\n') { sbuf_append(sbuf,"\\n"); } + else if (c == '\r') { /* ignore */ } // sbuf_append(sbuf,"\\r"); } + else if (c == '\t') { sbuf_append(sbuf,"\\t"); } + else if (c < ' ' || c > '~' || c == '#') { + char c1 = to_xdigit( (uint8_t)c / 16 ); + char c2 = to_xdigit( (uint8_t)c % 16 ); + sbuf_append(sbuf,"\\x"); + sbuf_append_char(sbuf,c1); + sbuf_append_char(sbuf,c2); + } + else sbuf_append_char(sbuf,c); + } + //debug_msg("history: write buf: %s\n", sbuf_string(sbuf)); + + if (sbuf_len(sbuf) > 0) { + sbuf_append(sbuf,"\n"); + fputs(sbuf_string(sbuf),f); + } + return true; +} + +ic_private void history_load( history_t* h ) { + if (h->fname == NULL) return; + FILE* f = fopen(h->fname, "r"); + if (f == NULL) return; + stringbuf_t* sbuf = sbuf_new(h->mem); + if (sbuf != NULL) { + while (!feof(f)) { + if (!history_read_entry(h,f,sbuf)) break; // error + } + sbuf_free(sbuf); + } + fclose(f); +} + +ic_private void history_save( const history_t* h ) { + if (h->fname == NULL) return; + FILE* f = fopen(h->fname, "w"); + if (f == NULL) return; + #ifndef _WIN32 + chmod(h->fname,S_IRUSR|S_IWUSR); + #endif + stringbuf_t* sbuf = sbuf_new(h->mem); + if (sbuf != NULL) { + for( int i = 0; i < h->count; i++ ) { + if (!history_write_entry(h->elems[i],f,sbuf)) break; // error + } + sbuf_free(sbuf); + } + fclose(f); +} diff --git a/extern/isocline/src/history.h b/extern/isocline/src/history.h new file mode 100644 index 000000000..76a37160f --- /dev/null +++ b/extern/isocline/src/history.h @@ -0,0 +1,38 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_HISTORY_H +#define IC_HISTORY_H + +#include "common.h" + +//------------------------------------------------------------- +// History +//------------------------------------------------------------- + +struct history_s; +typedef struct history_s history_t; + +ic_private history_t* history_new(alloc_t* mem); +ic_private void history_free(history_t* h); +ic_private void history_clear(history_t* h); +ic_private bool history_enable_duplicates( history_t* h, bool enable ); +ic_private ssize_t history_count(const history_t* h); + +ic_private void history_load_from(history_t* h, const char* fname, long max_entries); +ic_private void history_load( history_t* h ); +ic_private void history_save( const history_t* h ); + +ic_private bool history_push( history_t* h, const char* entry ); +ic_private bool history_update( history_t* h, const char* entry ); +ic_private const char* history_get( const history_t* h, ssize_t n ); +ic_private void history_remove_last(history_t* h); + +ic_private bool history_search( const history_t* h, ssize_t from, const char* search, bool backward, ssize_t* hidx, ssize_t* hpos); + + +#endif // IC_HISTORY_H diff --git a/extern/isocline/src/isocline.c b/extern/isocline/src/isocline.c new file mode 100644 index 000000000..132780628 --- /dev/null +++ b/extern/isocline/src/isocline.c @@ -0,0 +1,589 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Usually we include all sources one file so no internal +// symbols are public in the libray. +// +// You can compile the entire library just as: +// $ gcc -c src/isocline.c +//------------------------------------------------------------- +#if !defined(IC_SEPARATE_OBJS) +# define _CRT_SECURE_NO_WARNINGS // for msvc +# define _XOPEN_SOURCE 700 // for wcwidth +# define _DEFAULT_SOURCE // ensure usleep stays visible with _XOPEN_SOURCE >= 700 +# include "attr.c" +# include "bbcode.c" +# include "editline.c" +# include "highlight.c" +# include "undo.c" +# include "history.c" +# include "completers.c" +# include "completions.c" +# include "term.c" +# include "tty_esc.c" +# include "tty.c" +# include "stringbuf.c" +# include "common.c" +#endif + +//------------------------------------------------------------- +// includes +//------------------------------------------------------------- +#include +#include +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" + + +//------------------------------------------------------------- +// Readline +//------------------------------------------------------------- + +static char* ic_getline( alloc_t* mem ); + +ic_public char* ic_readline(const char* prompt_text) +{ + ic_env_t* env = ic_get_env(); + if (env == NULL) return NULL; + if (!env->noedit) { + // terminal editing enabled + return ic_editline(env, prompt_text); // in editline.c + } + else { + // no editing capability (pipe, dumb terminal, etc) + if (env->tty != NULL && env->term != NULL) { + // if the terminal is not interactive, but we are reading from the tty (keyboard), we display a prompt + term_start_raw(env->term); // set utf8 mode on windows + if (prompt_text != NULL) { + term_write(env->term, prompt_text); + } + term_write(env->term, env->prompt_marker); + term_end_raw(env->term, false); + } + // read directly from stdin + return ic_getline(env->mem); + } +} + + +//------------------------------------------------------------- +// Read a line from the stdin stream if there is no editing +// support (like from a pipe, file, or dumb terminal). +//------------------------------------------------------------- + +static char* ic_getline(alloc_t* mem) +{ + // read until eof or newline + stringbuf_t* sb = sbuf_new(mem); + int c; + while (true) { + c = fgetc(stdin); + if (c==EOF || c=='\n') { + break; + } + else { + sbuf_append_char(sb, (char)c); + } + } + return sbuf_free_dup(sb); +} + + +//------------------------------------------------------------- +// Formatted output +//------------------------------------------------------------- + + +ic_public void ic_printf(const char* fmt, ...) { + va_list ap; + va_start(ap, fmt); + ic_vprintf(fmt, ap); + va_end(ap); +} + +ic_public void ic_vprintf(const char* fmt, va_list args) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode == NULL) return; + bbcode_vprintf(env->bbcode, fmt, args); +} + +ic_public void ic_print(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_print(env->bbcode, s); +} + +ic_public void ic_println(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_println(env->bbcode, s); +} + +void ic_style_def(const char* name, const char* fmt) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_style_def(env->bbcode, name, fmt); +} + +void ic_style_open(const char* fmt) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_style_open(env->bbcode, fmt); +} + +void ic_style_close(void) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->bbcode==NULL) return; + bbcode_style_close(env->bbcode, NULL); +} + + +//------------------------------------------------------------- +// Interface +//------------------------------------------------------------- + +ic_public bool ic_async_stop(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + if (env->tty==NULL) return false; + return tty_async_stop(env->tty); +} + +static void set_prompt_marker(ic_env_t* env, const char* prompt_marker, const char* cprompt_marker) { + if (prompt_marker == NULL) prompt_marker = "> "; + if (cprompt_marker == NULL) cprompt_marker = prompt_marker; + mem_free(env->mem, env->prompt_marker); + mem_free(env->mem, env->cprompt_marker); + env->prompt_marker = mem_strdup(env->mem, prompt_marker); + env->cprompt_marker = mem_strdup(env->mem, cprompt_marker); +} + +ic_public const char* ic_get_prompt_marker(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + return env->prompt_marker; +} + +ic_public const char* ic_get_continuation_prompt_marker(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + return env->cprompt_marker; +} + +ic_public void ic_set_prompt_marker( const char* prompt_marker, const char* cprompt_marker ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + set_prompt_marker(env, prompt_marker, cprompt_marker); +} + +ic_public bool ic_enable_multiline( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->singleline_only; + env->singleline_only = !enable; + return !prev; +} + +ic_public bool ic_enable_beep( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + return term_enable_beep(env->term, enable); +} + +ic_public bool ic_enable_color( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + return term_enable_color( env->term, enable ); +} + +ic_public bool ic_enable_history_duplicates( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + return history_enable_duplicates(env->history, enable); +} + +ic_public void ic_set_history(const char* fname, long max_entries ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_load_from(env->history, fname, max_entries ); +} + +ic_public void ic_history_remove_last(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_remove_last(env->history); +} + +ic_public void ic_history_add( const char* entry ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_push( env->history, entry ); +} + +ic_public void ic_history_clear(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + history_clear(env->history); +} + +ic_public bool ic_enable_auto_tab( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->complete_autotab; + env->complete_autotab = enable; + return prev; +} + +ic_public bool ic_enable_completion_preview( bool enable ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->complete_nopreview; + env->complete_nopreview = !enable; + return !prev; +} + +ic_public bool ic_enable_multiline_indent(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_multiline_indent; + env->no_multiline_indent = !enable; + return !prev; +} + +ic_public bool ic_enable_hint(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_hint; + env->no_hint = !enable; + return !prev; +} + +ic_public long ic_set_hint_delay(long delay_ms) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + long prev = env->hint_delay; + env->hint_delay = (delay_ms < 0 ? 0 : (delay_ms > 5000 ? 5000 : delay_ms)); + return prev; +} + +ic_public void ic_set_tty_esc_delay(long initial_delay_ms, long followup_delay_ms ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->tty == NULL) return; + tty_set_esc_delay(env->tty, initial_delay_ms, followup_delay_ms); +} + + +ic_public bool ic_enable_highlight(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_highlight; + env->no_highlight = !enable; + return !prev; +} + +ic_public bool ic_enable_inline_help(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_help; + env->no_help = !enable; + return !prev; +} + +ic_public bool ic_enable_brace_matching(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_bracematch; + env->no_bracematch = !enable; + return !prev; +} + +ic_public void ic_set_matching_braces(const char* brace_pairs) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + mem_free(env->mem, env->match_braces); + env->match_braces = NULL; + if (brace_pairs != NULL) { + ssize_t len = ic_strlen(brace_pairs); + if (len > 0 && (len % 2) == 0) { + env->match_braces = mem_strdup(env->mem, brace_pairs); + } + } +} + +ic_public bool ic_enable_brace_insertion(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL) return false; + bool prev = env->no_autobrace; + env->no_autobrace = !enable; + return !prev; +} + +ic_public void ic_set_insertion_braces(const char* brace_pairs) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + mem_free(env->mem, env->auto_braces); + env->auto_braces = NULL; + if (brace_pairs != NULL) { + ssize_t len = ic_strlen(brace_pairs); + if (len > 0 && (len % 2) == 0) { + env->auto_braces = mem_strdup(env->mem, brace_pairs); + } + } +} + +ic_private const char* ic_env_get_match_braces(ic_env_t* env) { + return (env->match_braces == NULL ? "()[]{}" : env->match_braces); +} + +ic_private const char* ic_env_get_auto_braces(ic_env_t* env) { + return (env->auto_braces == NULL ? "()[]{}\"\"''" : env->auto_braces); +} + +ic_public void ic_set_default_highlighter(ic_highlight_fun_t* highlighter, void* arg) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + env->highlighter = highlighter; + env->highlighter_arg = arg; +} + + +ic_public void ic_free( void* p ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + mem_free(env->mem, p); +} + +ic_public void* ic_malloc(size_t sz) { + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + return mem_malloc(env->mem, to_ssize_t(sz)); +} + +ic_public const char* ic_strdup( const char* s ) { + if (s==NULL) return NULL; + ic_env_t* env = ic_get_env(); if (env==NULL) return NULL; + ssize_t len = ic_strlen(s); + char* p = mem_malloc_tp_n( env->mem, char, len + 1 ); + if (p == NULL) return NULL; + ic_memcpy( p, s, len ); + p[len] = 0; + return p; +} + +//------------------------------------------------------------- +// Terminal +//------------------------------------------------------------- + +ic_public void ic_term_init(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term==NULL) return; + term_start_raw(env->term); +} + +ic_public void ic_term_done(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term==NULL) return; + term_end_raw(env->term,false); +} + +ic_public void ic_term_flush(void) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term==NULL) return; + term_flush(env->term); +} + +ic_public void ic_term_write(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_write(env->term, s); +} + +ic_public void ic_term_writeln(const char* s) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_writeln(env->term, s); +} + +ic_public void ic_term_writef(const char* fmt, ...) { + va_list ap; + va_start(ap, fmt); + ic_term_vwritef(fmt, ap); + va_end(ap); +} + +ic_public void ic_term_vwritef(const char* fmt, va_list args) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_vwritef(env->term, fmt, args); +} + +ic_public void ic_term_reset( void ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL) return; + term_attr_reset(env->term); +} + +ic_public void ic_term_style( const char* style ) { + ic_env_t* env = ic_get_env(); if (env==NULL) return; + if (env->term == NULL || env->bbcode == NULL) return; + term_set_attr( env->term, bbcode_style(env->bbcode, style)); +} + +ic_public int ic_term_get_color_bits(void) { + ic_env_t* env = ic_get_env(); + if (env==NULL || env->term==NULL) return 4; + return term_get_color_bits(env->term); +} + +ic_public void ic_term_bold(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_bold(env->term, enable); +} + +ic_public void ic_term_underline(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_underline(env->term, enable); +} + +ic_public void ic_term_italic(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_italic(env->term, enable); +} + +ic_public void ic_term_reverse(bool enable) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + term_reverse(env->term, enable); +} + +ic_public void ic_term_color_ansi(bool foreground, int ansi_color) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + ic_color_t color = color_from_ansi256(ansi_color); + if (foreground) { term_color(env->term, color); } + else { term_bgcolor(env->term, color); } +} + +ic_public void ic_term_color_rgb(bool foreground, uint32_t hcolor) { + ic_env_t* env = ic_get_env(); if (env==NULL || env->term==NULL) return; + ic_color_t color = ic_rgb(hcolor); + if (foreground) { term_color(env->term, color); } + else { term_bgcolor(env->term, color); } +} + + +//------------------------------------------------------------- +// Readline with temporary completer and highlighter +//------------------------------------------------------------- + +ic_public char* ic_readline_ex(const char* prompt_text, + ic_completer_fun_t* completer, void* completer_arg, + ic_highlight_fun_t* highlighter, void* highlighter_arg ) +{ + ic_env_t* env = ic_get_env(); if (env == NULL) return NULL; + // save previous + ic_completer_fun_t* prev_completer; + void* prev_completer_arg; + completions_get_completer(env->completions, &prev_completer, &prev_completer_arg); + ic_highlight_fun_t* prev_highlighter = env->highlighter; + void* prev_highlighter_arg = env->highlighter_arg; + // call with current + if (completer != NULL) { ic_set_default_completer(completer, completer_arg); } + if (highlighter != NULL) { ic_set_default_highlighter(highlighter, highlighter_arg); } + char* res = ic_readline(prompt_text); + // restore previous + ic_set_default_completer(prev_completer, prev_completer_arg); + ic_set_default_highlighter(prev_highlighter, prev_highlighter_arg); + return res; +} + + +//------------------------------------------------------------- +// Initialize +//------------------------------------------------------------- + +static void ic_atexit(void); + +static void ic_env_free(ic_env_t* env) { + if (env == NULL) return; + history_save(env->history); + history_free(env->history); + completions_free(env->completions); + bbcode_free(env->bbcode); + term_free(env->term); + tty_free(env->tty); + mem_free(env->mem, env->cprompt_marker); + mem_free(env->mem,env->prompt_marker); + mem_free(env->mem, env->match_braces); + mem_free(env->mem, env->auto_braces); + env->prompt_marker = NULL; + + // and deallocate ourselves + alloc_t* mem = env->mem; + mem_free(mem, env); + + // and finally the custom memory allocation structure + mem_free(mem, mem); +} + + +static ic_env_t* ic_env_create( ic_malloc_fun_t* _malloc, ic_realloc_fun_t* _realloc, ic_free_fun_t* _free ) +{ + if (_malloc == NULL) _malloc = &malloc; + if (_realloc == NULL) _realloc = &realloc; + if (_free == NULL) _free = &free; + // allocate + alloc_t* mem = (alloc_t*)_malloc(sizeof(alloc_t)); + if (mem == NULL) return NULL; + mem->malloc = _malloc; + mem->realloc = _realloc; + mem->free = _free; + ic_env_t* env = mem_zalloc_tp(mem, ic_env_t); + if (env==NULL) { + mem->free(mem); + return NULL; + } + env->mem = mem; + + // Initialize + env->tty = tty_new(env->mem, -1); // can return NULL + env->term = term_new(env->mem, env->tty, false, false, -1 ); + env->history = history_new(env->mem); + env->completions = completions_new(env->mem); + env->bbcode = bbcode_new(env->mem, env->term); + env->hint_delay = 400; + + if (env->tty == NULL || env->term==NULL || + env->completions == NULL || env->history == NULL || env->bbcode == NULL || + !term_is_interactive(env->term)) + { + env->noedit = true; + } + env->multiline_eol = '\\'; + + bbcode_style_def(env->bbcode, "ic-prompt", "ansi-green" ); + bbcode_style_def(env->bbcode, "ic-info", "ansi-darkgray" ); + bbcode_style_def(env->bbcode, "ic-diminish", "ansi-lightgray" ); + bbcode_style_def(env->bbcode, "ic-emphasis", "#ffffd7" ); + bbcode_style_def(env->bbcode, "ic-hint", "ansi-darkgray" ); + bbcode_style_def(env->bbcode, "ic-error", "#d70000" ); + bbcode_style_def(env->bbcode, "ic-bracematch","ansi-white"); // color = #F7DC6F" ); + + bbcode_style_def(env->bbcode, "keyword", "#569cd6" ); + bbcode_style_def(env->bbcode, "control", "#c586c0" ); + bbcode_style_def(env->bbcode, "number", "#b5cea8" ); + bbcode_style_def(env->bbcode, "string", "#ce9178" ); + bbcode_style_def(env->bbcode, "comment", "#6A9955" ); + bbcode_style_def(env->bbcode, "type", "darkcyan" ); + bbcode_style_def(env->bbcode, "constant", "#569cd6" ); + + set_prompt_marker(env, NULL, NULL); + return env; +} + +static ic_env_t* rpenv; + +static void ic_atexit(void) { + if (rpenv != NULL) { + ic_env_free(rpenv); + rpenv = NULL; + } +} + +ic_private ic_env_t* ic_get_env(void) { + if (rpenv==NULL) { + rpenv = ic_env_create( NULL, NULL, NULL ); + if (rpenv != NULL) { atexit( &ic_atexit ); } + } + return rpenv; +} + +ic_public void ic_init_custom_malloc( ic_malloc_fun_t* _malloc, ic_realloc_fun_t* _realloc, ic_free_fun_t* _free ) { + assert(rpenv == NULL); + if (rpenv != NULL) { + ic_env_free(rpenv); + rpenv = ic_env_create( _malloc, _realloc, _free ); + } + else { + rpenv = ic_env_create( _malloc, _realloc, _free ); + if (rpenv != NULL) { + atexit( &ic_atexit ); + } + } +} + diff --git a/extern/isocline/src/stringbuf.c b/extern/isocline/src/stringbuf.c new file mode 100644 index 000000000..7bbfad049 --- /dev/null +++ b/extern/isocline/src/stringbuf.c @@ -0,0 +1,1038 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +// get `wcwidth` for the column width of unicode characters +// note: for now the OS provided one is unused as we see quite a bit of variation +// among platforms and including our own seems more reliable. +/* +#if defined(__linux__) || defined(__freebsd__) +// use the system supplied one +#if !defined(_XOPEN_SOURCE) +#define _XOPEN_SOURCE 700 // so wcwidth is visible +#endif +#include +#else +*/ +// use our own (also on APPLE as that fails within vscode) +#define wcwidth(c) mk_wcwidth(c) +#include "wcwidth.c" +// #endif + +#include +#include +#include + +#include "common.h" +#include "stringbuf.h" + +//------------------------------------------------------------- +// In place growable utf-8 strings +//------------------------------------------------------------- + +struct stringbuf_s { + char* buf; + ssize_t buflen; + ssize_t count; + alloc_t* mem; +}; + + +//------------------------------------------------------------- +// String column width +//------------------------------------------------------------- + +// column width of a utf8 single character sequence. +static ssize_t utf8_char_width( const char* s, ssize_t n ) { + if (n <= 0) return 0; + + uint8_t b = (uint8_t)s[0]; + int32_t c; + if (b < ' ') { + return 0; + } + else if (b <= 0x7F) { + return 1; + } + else if (b <= 0xC1) { // invalid continuation byte or invalid 0xC0, 0xC1 (check is strictly not necessary as we don't validate..) + return 1; + } + else if (b <= 0xDF && n >= 2) { // b >= 0xC2 // 2 bytes + c = (((b & 0x1F) << 6) | (s[1] & 0x3F)); + assert(c < 0xD800 || c > 0xDFFF); + int w = wcwidth(c); + return w; + } + else if (b <= 0xEF && n >= 3) { // b >= 0xE0 // 3 bytes + c = (((b & 0x0F) << 12) | ((s[1] & 0x3F) << 6) | (s[2] & 0x3F)); + return wcwidth(c); + } + else if (b <= 0xF4 && n >= 4) { // b >= 0xF0 // 4 bytes + c = (((b & 0x07) << 18) | ((s[1] & 0x3F) << 12) | ((s[2] & 0x3F) << 6) | (s[3] & 0x3F)); + return wcwidth(c); + } + else { + // failed + return 1; + } +} + + +// The column width of a codepoint (0, 1, or 2) +static ssize_t char_column_width( const char* s, ssize_t n ) { + if (s == NULL || n <= 0) return 0; + else if ((uint8_t)(*s) < ' ') return 0; // also for CSI escape sequences + else { + ssize_t w = utf8_char_width(s, n); + #ifdef _WIN32 + return (w <= 0 ? 1 : w); // windows console seems to use at least one column + #else + return w; + #endif + } +} + +static ssize_t str_column_width_n( const char* s, ssize_t len ) { + if (s == NULL || len <= 0) return 0; + ssize_t pos = 0; + ssize_t cwidth = 0; + ssize_t cw; + ssize_t ofs; + while (s[pos] != 0 && (ofs = str_next_ofs(s, len, pos, &cw)) > 0) { + cwidth += cw; + pos += ofs; + } + return cwidth; +} + +ic_private ssize_t str_column_width( const char* s ) { + return str_column_width_n( s, ic_strlen(s) ); +} + +ic_private ssize_t str_skip_until_fit( const char* s, ssize_t max_width ) { + if (s == NULL) return 0; + ssize_t cwidth = str_column_width(s); + ssize_t len = ic_strlen(s); + ssize_t pos = 0; + ssize_t next; + ssize_t cw; + while (cwidth > max_width && (next = str_next_ofs(s, len, pos, &cw)) > 0) { + cwidth -= cw; + pos += next; + } + return pos; +} + +ic_private ssize_t str_take_while_fit( const char* s, ssize_t max_width) { + if (s == NULL) return 0; + const ssize_t len = ic_strlen(s); + ssize_t pos = 0; + ssize_t next; + ssize_t cw; + ssize_t cwidth = 0; + while ((next = str_next_ofs(s, len, pos, &cw)) > 0) { + if (cwidth + cw > max_width) break; + cwidth += cw; + pos += next; + } + return pos; +} + + +//------------------------------------------------------------- +// String navigation +//------------------------------------------------------------- + +// get offset of the previous codepoint. does not skip back over CSI sequences. +ic_private ssize_t str_prev_ofs( const char* s, ssize_t pos, ssize_t* width ) { + ssize_t ofs = 0; + if (s != NULL && pos > 0) { + ofs = 1; + while (pos > ofs) { + uint8_t u = (uint8_t)s[pos - ofs]; + if (u < 0x80 || u > 0xBF) break; // continue while follower + ofs++; + } + } + if (width != NULL) *width = char_column_width( s+(pos-ofs), ofs ); + return ofs; +} + +// skip an escape sequence +// +ic_private bool skip_esc( const char* s, ssize_t len, ssize_t* esclen ) { + if (s == NULL || len <= 1 || s[0] != '\x1B') return false; + if (esclen != NULL) *esclen = 0; + if (strchr("[PX^_]",s[1]) != NULL) { + // CSI (ESC [), DCS (ESC P), SOS (ESC X), PM (ESC ^), APC (ESC _), and OSC (ESC ]): terminated with a special sequence + bool finalCSI = (s[1] == '['); // CSI terminates with 0x40-0x7F; otherwise ST (bell or ESC \) + ssize_t n = 2; + while (len > n) { + char c = s[n++]; + if ((finalCSI && (uint8_t)c >= 0x40 && (uint8_t)c <= 0x7F) || // terminating byte: @A–Z[\]^_`a–z{|}~ + (!finalCSI && c == '\x07') || // bell + (c == '\x02')) // STX terminates as well + { + if (esclen != NULL) *esclen = n; + return true; + } + else if (!finalCSI && c == '\x1B' && len > n && s[n] == '\\') { // ST (ESC \) + n++; + if (esclen != NULL) *esclen = n; + return true; + } + } + } + if (strchr(" #%()*+",s[1]) != NULL) { + // assume escape sequence of length 3 (like ESC % G) + if (esclen != NULL) *esclen = 2; + return true; + } + else { + // assume single character escape code (like ESC 7) + if (esclen != NULL) *esclen = 2; + return true; + } + return false; +} + +// Offset to the next codepoint, treats CSI escape sequences as a single code point. +ic_private ssize_t str_next_ofs( const char* s, ssize_t len, ssize_t pos, ssize_t* cwidth ) { + ssize_t ofs = 0; + if (s != NULL && len > pos) { + if (skip_esc(s+pos,len-pos,&ofs)) { + // skip escape sequence + } + else { + ofs = 1; + // utf8 extended character? + while(len > pos + ofs) { + uint8_t u = (uint8_t)s[pos + ofs]; + if (u < 0x80 || u > 0xBF) break; // break if not a follower + ofs++; + } + } + } + if (cwidth != NULL) *cwidth = char_column_width( s+pos, ofs ); + return ofs; +} + +static ssize_t str_limit_to_length( const char* s, ssize_t n ) { + ssize_t i; + for(i = 0; i < n && s[i] != 0; i++) { /* nothing */ } + return i; +} + + +//------------------------------------------------------------- +// String searching prev/next word, line, ws_word +//------------------------------------------------------------- + + +static ssize_t str_find_backward( const char* s, ssize_t len, ssize_t pos, ic_is_char_class_fun_t* match, bool skip_immediate_matches ) { + if (pos > len) pos = len; + if (pos < 0) pos = 0; + ssize_t i = pos; + // skip matching first (say, whitespace in case of the previous start-of-word) + if (skip_immediate_matches) { + do { + ssize_t prev = str_prev_ofs(s, i, NULL); + if (prev <= 0) break; + assert(i - prev >= 0); + if (!match(s + i - prev, (long)prev)) break; + i -= prev; + } while (i > 0); + } + // find match + do { + ssize_t prev = str_prev_ofs(s, i, NULL); + if (prev <= 0) break; + assert(i - prev >= 0); + if (match(s + i - prev, (long)prev)) { + return i; // found; + } + i -= prev; + } while (i > 0); + return -1; // not found +} + +static ssize_t str_find_forward( const char* s, ssize_t len, ssize_t pos, ic_is_char_class_fun_t* match, bool skip_immediate_matches ) { + if (s == NULL || len < 0) return -1; + if (pos > len) pos = len; + if (pos < 0) pos = 0; + ssize_t i = pos; + ssize_t next; + // skip matching first (say, whitespace in case of the next end-of-word) + if (skip_immediate_matches) { + do { + next = str_next_ofs(s, len, i, NULL); + if (next <= 0) break; + assert( i + next <= len); + if (!match(s + i, (long)next)) break; + i += next; + } while (i < len); + } + // and then look + do { + next = str_next_ofs(s, len, i, NULL); + if (next <= 0) break; + assert( i + next <= len); + if (match(s + i, (long)next)) { + return i; // found + } + i += next; + } while (i < len); + return -1; +} + +static bool char_is_linefeed( const char* s, long n ) { + return (n == 1 && (*s == '\n' || *s == 0)); +} + +static ssize_t str_find_line_start( const char* s, ssize_t len, ssize_t pos) { + ssize_t start = str_find_backward(s,len,pos,&char_is_linefeed,false /* don't skip immediate matches */); + return (start < 0 ? 0 : start); +} + +static ssize_t str_find_line_end( const char* s, ssize_t len, ssize_t pos) { + ssize_t end = str_find_forward(s,len,pos, &char_is_linefeed, false); + return (end < 0 ? len : end); +} + +static ssize_t str_find_word_start( const char* s, ssize_t len, ssize_t pos) { + ssize_t start = str_find_backward(s,len,pos, &ic_char_is_idletter,true /* skip immediate matches */); + return (start < 0 ? 0 : start); +} + +static ssize_t str_find_word_end( const char* s, ssize_t len, ssize_t pos) { + ssize_t end = str_find_forward(s,len,pos,&ic_char_is_idletter,true /* skip immediate matches */); + return (end < 0 ? len : end); +} + +static ssize_t str_find_ws_word_start( const char* s, ssize_t len, ssize_t pos) { + ssize_t start = str_find_backward(s,len,pos,&ic_char_is_white,true /* skip immediate matches */); + return (start < 0 ? 0 : start); +} + +static ssize_t str_find_ws_word_end( const char* s, ssize_t len, ssize_t pos) { + ssize_t end = str_find_forward(s,len,pos,&ic_char_is_white,true /* skip immediate matches */); + return (end < 0 ? len : end); +} + + +//------------------------------------------------------------- +// String row/column iteration +//------------------------------------------------------------- + +// invoke a function for each terminal row; returns total row count. +static ssize_t str_for_each_row( const char* s, ssize_t len, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + row_fun_t* fun, const void* arg, void* res ) +{ + if (s == NULL) s = ""; + ssize_t i; + ssize_t rcount = 0; + ssize_t rcol = 0; + ssize_t rstart = 0; + ssize_t startw = promptw; + for(i = 0; i < len; ) { + ssize_t w; + ssize_t next = str_next_ofs(s, len, i, &w); + if (next <= 0) { + debug_msg("str: foreach row: next<=0: len %zd, i %zd, w %zd, buf %s\n", len, i, w, s ); + assert(false); + break; + } + startw = (rcount == 0 ? promptw : cpromptw); + ssize_t termcol = rcol + w + startw + 1 /* for the cursor */; + if (termw != 0 && i != 0 && termcol >= termw) { + // wrap + if (fun != NULL) { + if (fun(s,rcount,rstart,i - rstart,startw,true,arg,res)) return rcount; + } + rcount++; + rstart = i; + rcol = 0; + } + if (s[i] == '\n') { + // newline + if (fun != NULL) { + if (fun(s,rcount,rstart,i - rstart,startw,false,arg,res)) return rcount; + } + rcount++; + rstart = i+1; + rcol = 0; + } + assert (s[i] != 0); + i += next; + rcol += w; + } + if (fun != NULL) { + if (fun(s,rcount,rstart,i - rstart,startw,false,arg,res)) return rcount; + } + return rcount+1; +} + +//------------------------------------------------------------- +// String: get row/column position +//------------------------------------------------------------- + + +static bool str_get_current_pos_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(is_wrap); ic_unused(startw); + rowcol_t* rc = (rowcol_t*)res; + ssize_t pos = *((ssize_t*)arg); + + if (pos >= row_start && pos <= (row_start + row_len)) { + // found the cursor row + rc->row_start = row_start; + rc->row_len = row_len; + rc->row = row; + rc->col = str_column_width_n( s + row_start, pos - row_start ); + rc->first_on_row = (pos == row_start); + if (is_wrap) { + // if wrapped, we check if the next character is at row_len + ssize_t next = str_next_ofs(s, row_start + row_len, pos, NULL); + rc->last_on_row = (pos + next >= row_start + row_len); + } + else { + // normal last position is right after the last character + rc->last_on_row = (pos >= row_start + row_len); + } + // debug_msg("edit; pos iter: pos: %zd (%c), row_start: %zd, rowlen: %zd\n", pos, s[pos], row_start, row_len); + } + return false; // always continue to count all rows +} + +static ssize_t str_get_rc_at_pos(const char* s, ssize_t len, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc) { + memset(rc, 0, sizeof(*rc)); + ssize_t rows = str_for_each_row(s, len, termw, promptw, cpromptw, &str_get_current_pos_iter, &pos, rc); + // debug_msg("edit: current pos: (%d, %d) %s %s\n", rc->row, rc->col, rc->first_on_row ? "first" : "", rc->last_on_row ? "last" : ""); + return rows; +} + + + +//------------------------------------------------------------- +// String: get row/column position for a resized terminal +// with potentially "hard-wrapped" rows +//------------------------------------------------------------- +typedef struct wrapped_arg_s { + ssize_t pos; + ssize_t newtermw; +} wrapped_arg_t; + +typedef struct wrowcol_s { + rowcol_t rc; + ssize_t hrows; // count of hard-wrapped extra rows +} wrowcol_t; + +static bool str_get_current_wrapped_pos_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(is_wrap); + wrowcol_t* wrc = (wrowcol_t*)res; + const wrapped_arg_t* warg = (const wrapped_arg_t*)arg; + + // iterate through the row and record the postion and hard-wraps + ssize_t hwidth = startw; + ssize_t i = 0; + while( i <= row_len ) { // include rowlen as the cursor position can be just after the last character + // get next position and column width + ssize_t cw; + ssize_t next; + bool is_cursor = (warg->pos == row_start+i); + if (i < row_len) { + next = str_next_ofs(s + row_start, row_len, i, &cw); + } + else { + // end of row: take wrap or cursor into account + // (wrap has width 2 as it displays a back-arrow but also has an invisible newline that wraps) + cw = (is_wrap ? 2 : (is_cursor ? 1 : 0)); + next = 1; + } + + if (next > 0) { + if (hwidth + cw > warg->newtermw) { + // hardwrap + hwidth = 0; + wrc->hrows++; + debug_msg("str: found hardwrap: row: %zd, hrows: %zd\n", row, wrc->hrows); + } + } + else { + next++; // ensure we terminate (as we go up to rowlen) + } + + // did we find our position? + if (is_cursor) { + debug_msg("str: found position: row: %zd, hrows: %zd\n", row, wrc->hrows); + wrc->rc.row_start = row_start; + wrc->rc.row_len = row_len; + wrc->rc.row = wrc->hrows + row; + wrc->rc.col = hwidth; + wrc->rc.first_on_row = (i==0); + wrc->rc.last_on_row = (i+next >= row_len - (is_wrap ? 1 : 0)); + } + + // advance + hwidth += cw; + i += next; + } + return false; // always continue to count all rows +} + + +static ssize_t str_get_wrapped_rc_at_pos(const char* s, ssize_t len, ssize_t termw, ssize_t newtermw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc) { + wrapped_arg_t warg; + warg.pos = pos; + warg.newtermw = newtermw; + wrowcol_t wrc; + memset(&wrc,0,sizeof(wrc)); + ssize_t rows = str_for_each_row(s, len, termw, promptw, cpromptw, &str_get_current_wrapped_pos_iter, &warg, &wrc); + debug_msg("edit: wrapped pos: (%zd,%zd) rows %zd %s %s, hrows: %zd\n", wrc.rc.row, wrc.rc.col, rows, wrc.rc.first_on_row ? "first" : "", wrc.rc.last_on_row ? "last" : "", wrc.hrows); + *rc = wrc.rc; + return (rows + wrc.hrows); +} + + +//------------------------------------------------------------- +// Set position +//------------------------------------------------------------- + +static bool str_set_pos_iter( + const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, bool is_wrap, const void* arg, void* res) +{ + ic_unused(arg); ic_unused(is_wrap); ic_unused(startw); + rowcol_t* rc = (rowcol_t*)arg; + if (rc->row != row) return false; // keep searching + // we found our row + ssize_t col = 0; + ssize_t i = row_start; + ssize_t end = row_start + row_len; + while (col < rc->col && i < end) { + ssize_t cw; + ssize_t next = str_next_ofs(s, row_start + row_len, i, &cw); + if (next <= 0) break; + i += next; + col += cw; + } + *((ssize_t*)res) = i; + return true; // stop iteration +} + +static ssize_t str_get_pos_at_rc(const char* s, ssize_t len, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t row, ssize_t col /* without prompt */) { + rowcol_t rc; + memset(&rc,0,ssizeof(rc)); + rc.row = row; + rc.col = col; + ssize_t pos = -1; + str_for_each_row(s,len,termw,promptw,cpromptw,&str_set_pos_iter,&rc,&pos); + return pos; +} + + +//------------------------------------------------------------- +// String buffer +//------------------------------------------------------------- +static bool sbuf_ensure_extra(stringbuf_t* s, ssize_t extra) +{ + if (s->buflen >= s->count + extra) return true; + // reallocate; pick good initial size and multiples to increase reuse on allocation + ssize_t newlen = (s->buflen <= 0 ? 120 : (s->buflen > 1000 ? s->buflen + 1000 : 2*s->buflen)); + if (newlen < s->count + extra) newlen = s->count + extra; + if (s->buflen > 0) { + debug_msg("stringbuf: reallocate: old %zd, new %zd\n", s->buflen, newlen); + } + char* newbuf = mem_realloc_tp(s->mem, char, s->buf, newlen+1); // one more for terminating zero + if (newbuf == NULL) { + assert(false); + return false; + } + s->buf = newbuf; + s->buflen = newlen; + s->buf[s->count] = s->buf[s->buflen] = 0; + assert(s->buflen >= s->count + extra); + return true; +} + +static void sbuf_init( stringbuf_t* sbuf, alloc_t* mem ) { + sbuf->mem = mem; + sbuf->buf = NULL; + sbuf->buflen = 0; + sbuf->count = 0; +} + +static void sbuf_done( stringbuf_t* sbuf ) { + mem_free( sbuf->mem, sbuf->buf ); + sbuf->buf = NULL; + sbuf->buflen = 0; + sbuf->count = 0; +} + + +ic_private void sbuf_free( stringbuf_t* sbuf ) { + if (sbuf==NULL) return; + sbuf_done(sbuf); + mem_free(sbuf->mem, sbuf); +} + +ic_private stringbuf_t* sbuf_new( alloc_t* mem ) { + stringbuf_t* sbuf = mem_zalloc_tp(mem,stringbuf_t); + if (sbuf == NULL) return NULL; + sbuf_init(sbuf,mem); + return sbuf; +} + +// free the sbuf and return the current string buffer as the result +ic_private char* sbuf_free_dup(stringbuf_t* sbuf) { + if (sbuf == NULL) return NULL; + char* s = NULL; + if (sbuf->buf != NULL) { + s = mem_realloc_tp(sbuf->mem, char, sbuf->buf, sbuf_len(sbuf)+1); + if (s == NULL) { s = sbuf->buf; } + sbuf->buf = 0; + sbuf->buflen = 0; + sbuf->count = 0; + } + sbuf_free(sbuf); + return s; +} + +ic_private const char* sbuf_string_at( stringbuf_t* sbuf, ssize_t pos ) { + if (pos < 0 || sbuf->count < pos) return NULL; + if (sbuf->buf == NULL) return ""; + assert(sbuf->buf[sbuf->count] == 0); + return sbuf->buf + pos; +} + +ic_private const char* sbuf_string( stringbuf_t* sbuf ) { + return sbuf_string_at( sbuf, 0 ); +} + +ic_private char sbuf_char_at(stringbuf_t* sbuf, ssize_t pos) { + if (sbuf->buf == NULL || pos < 0 || sbuf->count < pos) return 0; + return sbuf->buf[pos]; +} + +ic_private char* sbuf_strdup_at( stringbuf_t* sbuf, ssize_t pos ) { + return mem_strdup(sbuf->mem, sbuf_string_at(sbuf,pos)); +} + +ic_private char* sbuf_strdup( stringbuf_t* sbuf ) { + return mem_strdup(sbuf->mem, sbuf_string(sbuf)); +} + +ic_private ssize_t sbuf_len(const stringbuf_t* s) { + if (s == NULL) return 0; + return s->count; +} + +ic_private ssize_t sbuf_append_vprintf(stringbuf_t* sb, const char* fmt, va_list args) { + const ssize_t min_needed = ic_strlen(fmt); + if (!sbuf_ensure_extra(sb,min_needed + 16)) return sb->count; + ssize_t avail = sb->buflen - sb->count; + va_list args0; + va_copy(args0, args); + ssize_t needed = vsnprintf(sb->buf + sb->count, to_size_t(avail), fmt, args0); + if (needed > avail) { + sb->buf[sb->count] = 0; + if (!sbuf_ensure_extra(sb, needed)) return sb->count; + avail = sb->buflen - sb->count; + needed = vsnprintf(sb->buf + sb->count, to_size_t(avail), fmt, args); + } + assert(needed <= avail); + sb->count += (needed > avail ? avail : (needed >= 0 ? needed : 0)); + assert(sb->count <= sb->buflen); + sb->buf[sb->count] = 0; + return sb->count; +} + +ic_private ssize_t sbuf_appendf(stringbuf_t* sb, const char* fmt, ...) { + va_list args; + va_start( args, fmt); + ssize_t res = sbuf_append_vprintf( sb, fmt, args ); + va_end(args); + return res; +} + + +ic_private ssize_t sbuf_insert_at_n(stringbuf_t* sbuf, const char* s, ssize_t n, ssize_t pos ) { + if (pos < 0 || pos > sbuf->count || s == NULL) return pos; + n = str_limit_to_length(s,n); + if (n <= 0 || !sbuf_ensure_extra(sbuf,n)) return pos; + ic_memmove(sbuf->buf + pos + n, sbuf->buf + pos, sbuf->count - pos); + ic_memcpy(sbuf->buf + pos, s, n); + sbuf->count += n; + sbuf->buf[sbuf->count] = 0; + return (pos + n); +} + +ic_private stringbuf_t* sbuf_split_at( stringbuf_t* sb, ssize_t pos ) { + stringbuf_t* res = sbuf_new(sb->mem); + if (res==NULL || pos < 0) return NULL; + if (pos < sb->count) { + sbuf_append_n(res, sb->buf + pos, sb->count - pos); + sb->count = pos; + } + return res; +} + +ic_private ssize_t sbuf_insert_at(stringbuf_t* sbuf, const char* s, ssize_t pos ) { + return sbuf_insert_at_n( sbuf, s, ic_strlen(s), pos ); +} + +ic_private ssize_t sbuf_insert_char_at(stringbuf_t* sbuf, char c, ssize_t pos ) { + char s[2]; + s[0] = c; + s[1] = 0; + return sbuf_insert_at_n( sbuf, s, 1, pos); +} + +ic_private ssize_t sbuf_insert_unicode_at(stringbuf_t* sbuf, unicode_t u, ssize_t pos) { + uint8_t s[5]; + unicode_to_qutf8(u, s); + return sbuf_insert_at(sbuf, (const char*)s, pos); +} + + + +ic_private void sbuf_delete_at( stringbuf_t* sbuf, ssize_t pos, ssize_t count ) { + if (pos < 0 || pos >= sbuf->count) return; + if (pos + count > sbuf->count) count = sbuf->count - pos; + ic_memmove(sbuf->buf + pos, sbuf->buf + pos + count, sbuf->count - pos - count); + sbuf->count -= count; + sbuf->buf[sbuf->count] = 0; +} + +ic_private void sbuf_delete_from_to( stringbuf_t* sbuf, ssize_t pos, ssize_t end ) { + if (end <= pos) return; + sbuf_delete_at( sbuf, pos, end - pos); +} + +ic_private void sbuf_delete_from(stringbuf_t* sbuf, ssize_t pos ) { + sbuf_delete_at(sbuf, pos, sbuf_len(sbuf) - pos ); +} + + +ic_private void sbuf_clear( stringbuf_t* sbuf ) { + sbuf_delete_at(sbuf, 0, sbuf_len(sbuf)); +} + +ic_private ssize_t sbuf_append_n( stringbuf_t* sbuf, const char* s, ssize_t n ) { + return sbuf_insert_at_n( sbuf, s, n, sbuf_len(sbuf)); +} + +ic_private ssize_t sbuf_append( stringbuf_t* sbuf, const char* s ) { + return sbuf_insert_at( sbuf, s, sbuf_len(sbuf)); +} + +ic_private ssize_t sbuf_append_char( stringbuf_t* sbuf, char c ) { + char buf[2]; + buf[0] = c; + buf[1] = 0; + return sbuf_append( sbuf, buf ); +} + +ic_private void sbuf_replace(stringbuf_t* sbuf, const char* s) { + sbuf_clear(sbuf); + sbuf_append(sbuf,s); +} + +ic_private ssize_t sbuf_next_ofs( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ) { + return str_next_ofs( sbuf->buf, sbuf->count, pos, cwidth); +} + +ic_private ssize_t sbuf_prev_ofs( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ) { + return str_prev_ofs( sbuf->buf, pos, cwidth); +} + +ic_private ssize_t sbuf_next( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth) { + ssize_t ofs = sbuf_next_ofs(sbuf,pos,cwidth); + if (ofs <= 0) return -1; + assert(pos + ofs <= sbuf->count); + return pos + ofs; +} + +ic_private ssize_t sbuf_prev( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth) { + ssize_t ofs = sbuf_prev_ofs(sbuf,pos,cwidth); + if (ofs <= 0) return -1; + assert(pos - ofs >= 0); + return pos - ofs; +} + +ic_private ssize_t sbuf_delete_char_before( stringbuf_t* sbuf, ssize_t pos ) { + ssize_t n = sbuf_prev_ofs(sbuf, pos, NULL); + if (n <= 0) return 0; + assert( pos - n >= 0 ); + sbuf_delete_at(sbuf, pos - n, n); + return pos - n; +} + +ic_private void sbuf_delete_char_at( stringbuf_t* sbuf, ssize_t pos ) { + ssize_t n = sbuf_next_ofs(sbuf, pos, NULL); + if (n <= 0) return; + assert( pos + n <= sbuf->count ); + sbuf_delete_at(sbuf, pos, n); + return; +} + +ic_private ssize_t sbuf_swap_char( stringbuf_t* sbuf, ssize_t pos ) { + ssize_t next = sbuf_next_ofs(sbuf, pos, NULL); + if (next <= 0) return 0; + ssize_t prev = sbuf_prev_ofs(sbuf, pos, NULL); + if (prev <= 0) return 0; + char buf[64]; + if (prev >= 63) return 0; + ic_memcpy(buf, sbuf->buf + pos - prev, prev ); + ic_memmove(sbuf->buf + pos - prev, sbuf->buf + pos, next); + ic_memmove(sbuf->buf + pos - prev + next, buf, prev); + return pos - prev; +} + +ic_private ssize_t sbuf_find_line_start( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_line_start( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_line_end( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_line_end( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_word_start( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_word_start( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_word_end( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_word_end( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_ws_word_start( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_ws_word_start( sbuf->buf, sbuf->count, pos); +} + +ic_private ssize_t sbuf_find_ws_word_end( stringbuf_t* sbuf, ssize_t pos ) { + return str_find_ws_word_end( sbuf->buf, sbuf->count, pos); +} + +// find row/col position +ic_private ssize_t sbuf_get_pos_at_rc( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t row, ssize_t col ) { + return str_get_pos_at_rc( sbuf->buf, sbuf->count, termw, promptw, cpromptw, row, col); +} + +// get row/col for a given position +ic_private ssize_t sbuf_get_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc ) { + return str_get_rc_at_pos( sbuf->buf, sbuf->count, termw, promptw, cpromptw, pos, rc); +} + +ic_private ssize_t sbuf_get_wrapped_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t newtermw, ssize_t promptw, ssize_t cpromptw, ssize_t pos, rowcol_t* rc ) { + return str_get_wrapped_rc_at_pos( sbuf->buf, sbuf->count, termw, newtermw, promptw, cpromptw, pos, rc); +} + +ic_private ssize_t sbuf_for_each_row( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, row_fun_t* fun, void* arg, void* res ) { + if (sbuf == NULL) return 0; + return str_for_each_row( sbuf->buf, sbuf->count, termw, promptw, cpromptw, fun, arg, res); +} + + +// Duplicate and decode from utf-8 (for non-utf8 terminals) +ic_private char* sbuf_strdup_from_utf8(stringbuf_t* sbuf) { + ssize_t len = sbuf_len(sbuf); + if (sbuf == NULL || len <= 0) return NULL; + char* s = mem_zalloc_tp_n(sbuf->mem, char, len); + if (s == NULL) return NULL; + ssize_t dest = 0; + for (ssize_t i = 0; i < len; ) { + ssize_t ofs = sbuf_next_ofs(sbuf, i, NULL); + if (ofs <= 0) { + // invalid input + break; + } + else if (ofs == 1) { + // regular character + s[dest++] = sbuf->buf[i]; + } + else if (sbuf->buf[i] == '\x1B') { + // skip escape sequences + } + else { + // decode unicode + ssize_t nread; + unicode_t uchr = unicode_from_qutf8( (const uint8_t*)(sbuf->buf + i), ofs, &nread); + uint8_t c; + if (unicode_is_raw(uchr, &c)) { + // raw byte, output as is (this will take care of locale specific input) + s[dest++] = (char)c; + } + else if (uchr <= 0x7F) { + // allow ascii + s[dest++] = (char)uchr; + } + else { + // skip unknown unicode characters.. + // todo: convert according to locale? + } + } + i += ofs; + } + assert(dest <= len); + s[dest] = 0; + return s; +} + +//------------------------------------------------------------- +// String helpers +//------------------------------------------------------------- + +ic_public long ic_prev_char( const char* s, long pos ) { + ssize_t len = ic_strlen(s); + if (pos < 0 || pos > len) return -1; + ssize_t ofs = str_prev_ofs( s, pos, NULL ); + if (ofs <= 0) return -1; + return (long)(pos - ofs); +} + +ic_public long ic_next_char( const char* s, long pos ) { + ssize_t len = ic_strlen(s); + if (pos < 0 || pos > len) return -1; + ssize_t ofs = str_next_ofs( s, len, pos, NULL ); + if (ofs <= 0) return -1; + return (long)(pos + ofs); +} + + +// parse a decimal (leave pi unchanged on error) +ic_private bool ic_atoz(const char* s, ssize_t* pi) { + return (sscanf(s, "%zd", pi) == 1); +} + +// parse two decimals separated by a semicolon +ic_private bool ic_atoz2(const char* s, ssize_t* pi, ssize_t* pj) { + return (sscanf(s, "%zd;%zd", pi, pj) == 2); +} + +// parse unsigned 32-bit (leave pu unchanged on error) +ic_private bool ic_atou32(const char* s, uint32_t* pu) { + return (sscanf(s, "%" SCNu32, pu) == 1); +} + + +// Convenience: character class for whitespace `[ \t\r\n]`. +ic_public bool ic_char_is_white(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return (c==' ' || c == '\t' || c == '\n' || c == '\r'); +} + +// Convenience: character class for non-whitespace `[^ \t\r\n]`. +ic_public bool ic_char_is_nonwhite(const char* s, long len) { + return !ic_char_is_white(s, len); +} + +// Convenience: character class for separators `[ \t\r\n,.;:/\\\(\)\{\}\[\]]`. +ic_public bool ic_char_is_separator(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return (strchr(" \t\r\n,.;:/\\(){}[]", c) != NULL); +} + +// Convenience: character class for non-separators. +ic_public bool ic_char_is_nonseparator(const char* s, long len) { + return !ic_char_is_separator(s, len); +} + + +// Convenience: character class for digits (`[0-9]`). +ic_public bool ic_char_is_digit(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return (c >= '0' && c <= '9'); +} + +// Convenience: character class for hexadecimal digits (`[A-Fa-f0-9]`). +ic_public bool ic_char_is_hexdigit(const char* s, long len) { + if (s == NULL || len != 1) return false; + const char c = *s; + return ((c >= '0' && c <= '9') || (c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F')); +} + +// Convenience: character class for letters (`[A-Za-z]` and any unicode > 0x80). +ic_public bool ic_char_is_letter(const char* s, long len) { + if (s == NULL || len <= 0) return false; + const char c = *s; + return ((uint8_t)c >= 0x80 || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z')); +} + +// Convenience: character class for identifier letters (`[A-Za-z0-9_-]` and any unicode > 0x80). +ic_public bool ic_char_is_idletter(const char* s, long len) { + if (s == NULL || len <= 0) return false; + const char c = *s; + return ((uint8_t)c >= 0x80 || (c >= 'A' && c <= 'Z') || (c >= 'a' && c <= 'z') || (c >= '0' && c <= '9') || (c == '_') || (c == '-')); +} + +// Convenience: character class for filename letters (`[^ \t\r\n`@$><=;|&{(]`). +ic_public bool ic_char_is_filename_letter(const char* s, long len) { + if (s == NULL || len <= 0) return false; + const char c = *s; + return ((uint8_t)c >= 0x80 || (strchr(" \t\r\n`@$><=;|&{}()[]", c) == NULL)); +} + +// Convenience: If this is a token start, returns the length (or <= 0 if not found). +ic_public long ic_is_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char) { + if (s == NULL || pos < 0 || is_token_char == NULL) return -1; + ssize_t len = ic_strlen(s); + if (pos >= len) return -1; + if (pos > 0 && is_token_char(s + pos -1, 1)) return -1; // token start? + ssize_t i = pos; + while ( i < len ) { + ssize_t next = str_next_ofs(s, len, i, NULL); + if (next <= 0) return -1; + if (!is_token_char(s + i, (long)next)) break; + i += next; + } + return (long)(i - pos); +} + + +static int ic_strncmp(const char* s1, const char* s2, ssize_t n) { + return strncmp(s1, s2, to_size_t(n)); +} + +// Convenience: Does this match the specified token? +// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +// E.g. `ic_match_token("function",0,&ic_char_is_letter,"fun")` returns 0. +ic_public long ic_match_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char* token) { + long n = ic_is_token(s, pos, is_token_char); + if (n > 0 && token != NULL && n == ic_strlen(token) && ic_strncmp(s + pos, token, n) == 0) { + return n; + } + else { + return 0; + } +} + + +// Convenience: Do any of the specified tokens match? +// Ensures not to match prefixes or suffixes, and returns the length of the match (in bytes). +// Ensures not to match prefixes or suffixes. +// E.g. `ic_match_any_token("function",0,&ic_char_is_letter,{"fun","func",NULL})` returns 0. +ic_public long ic_match_any_token(const char* s, long pos, ic_is_char_class_fun_t* is_token_char, const char** tokens) { + long n = ic_is_token(s, pos, is_token_char); + if (n <= 0 || tokens == NULL) return 0; + for (const char** token = tokens; *token != NULL; token++) { + if (n == ic_strlen(*token) && ic_strncmp(s + pos, *token, n) == 0) { + return n; + } + } + return 0; +} + diff --git a/extern/isocline/src/stringbuf.h b/extern/isocline/src/stringbuf.h new file mode 100644 index 000000000..39b21ea42 --- /dev/null +++ b/extern/isocline/src/stringbuf.h @@ -0,0 +1,121 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_STRINGBUF_H +#define IC_STRINGBUF_H + +#include +#include "common.h" + +//------------------------------------------------------------- +// string buffer +// in-place modified buffer with edit operations +// that grows on demand. +//------------------------------------------------------------- + +// abstract string buffer +struct stringbuf_s; +typedef struct stringbuf_s stringbuf_t; + +ic_private stringbuf_t* sbuf_new( alloc_t* mem ); +ic_private void sbuf_free( stringbuf_t* sbuf ); +ic_private char* sbuf_free_dup(stringbuf_t* sbuf); +ic_private ssize_t sbuf_len(const stringbuf_t* s); + +ic_private const char* sbuf_string_at( stringbuf_t* sbuf, ssize_t pos ); +ic_private const char* sbuf_string( stringbuf_t* sbuf ); +ic_private char sbuf_char_at(stringbuf_t* sbuf, ssize_t pos); +ic_private char* sbuf_strdup_at( stringbuf_t* sbuf, ssize_t pos ); +ic_private char* sbuf_strdup( stringbuf_t* sbuf ); +ic_private char* sbuf_strdup_from_utf8(stringbuf_t* sbuf); // decode to locale + + +ic_private ssize_t sbuf_appendf(stringbuf_t* sb, const char* fmt, ...); +ic_private ssize_t sbuf_append_vprintf(stringbuf_t* sb, const char* fmt, va_list args); + +ic_private stringbuf_t* sbuf_split_at( stringbuf_t* sb, ssize_t pos ); + +// primitive edit operations (inserts return the new position) +ic_private void sbuf_clear(stringbuf_t* sbuf); +ic_private void sbuf_replace(stringbuf_t* sbuf, const char* s); +ic_private void sbuf_delete_at(stringbuf_t* sbuf, ssize_t pos, ssize_t count); +ic_private void sbuf_delete_from_to(stringbuf_t* sbuf, ssize_t pos, ssize_t end); +ic_private void sbuf_delete_from(stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_insert_at_n(stringbuf_t* sbuf, const char* s, ssize_t n, ssize_t pos ); +ic_private ssize_t sbuf_insert_at(stringbuf_t* sbuf, const char* s, ssize_t pos ); +ic_private ssize_t sbuf_insert_char_at(stringbuf_t* sbuf, char c, ssize_t pos ); +ic_private ssize_t sbuf_insert_unicode_at(stringbuf_t* sbuf, unicode_t u, ssize_t pos); +ic_private ssize_t sbuf_append_n(stringbuf_t* sbuf, const char* s, ssize_t n); +ic_private ssize_t sbuf_append(stringbuf_t* sbuf, const char* s); +ic_private ssize_t sbuf_append_char(stringbuf_t* sbuf, char c); + +// high level edit operations (return the new position) +ic_private ssize_t sbuf_next( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t sbuf_prev( stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t sbuf_next_ofs(stringbuf_t* sbuf, ssize_t pos, ssize_t* cwidth); + +ic_private ssize_t sbuf_delete_char_before( stringbuf_t* sbuf, ssize_t pos ); +ic_private void sbuf_delete_char_at( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_swap_char( stringbuf_t* sbuf, ssize_t pos ); + +ic_private ssize_t sbuf_find_line_start( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_line_end( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_word_start( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_word_end( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_ws_word_start( stringbuf_t* sbuf, ssize_t pos ); +ic_private ssize_t sbuf_find_ws_word_end( stringbuf_t* sbuf, ssize_t pos ); + +// parse a decimal +ic_private bool ic_atoz(const char* s, ssize_t* i); +// parse two decimals separated by a semicolon +ic_private bool ic_atoz2(const char* s, ssize_t* i, ssize_t* j); +ic_private bool ic_atou32(const char* s, uint32_t* pu); + +// row/column info +typedef struct rowcol_s { + ssize_t row; + ssize_t col; + ssize_t row_start; + ssize_t row_len; + bool first_on_row; + bool last_on_row; +} rowcol_t; + +// find row/col position +ic_private ssize_t sbuf_get_pos_at_rc( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + ssize_t row, ssize_t col ); +// get row/col for a given position +ic_private ssize_t sbuf_get_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + ssize_t pos, rowcol_t* rc ); + +ic_private ssize_t sbuf_get_wrapped_rc_at_pos( stringbuf_t* sbuf, ssize_t termw, ssize_t newtermw, ssize_t promptw, ssize_t cpromptw, + ssize_t pos, rowcol_t* rc ); + +// row iteration +typedef bool (row_fun_t)(const char* s, + ssize_t row, ssize_t row_start, ssize_t row_len, + ssize_t startw, // prompt width + bool is_wrap, const void* arg, void* res); + +ic_private ssize_t sbuf_for_each_row( stringbuf_t* sbuf, ssize_t termw, ssize_t promptw, ssize_t cpromptw, + row_fun_t* fun, void* arg, void* res ); + + +//------------------------------------------------------------- +// Strings +//------------------------------------------------------------- + +// skip a single CSI sequence (ESC [ ...) +ic_private bool skip_csi_esc( const char* s, ssize_t len, ssize_t* esclen ); // used in term.c + +ic_private ssize_t str_column_width( const char* s ); +ic_private ssize_t str_prev_ofs( const char* s, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t str_next_ofs( const char* s, ssize_t len, ssize_t pos, ssize_t* cwidth ); +ic_private ssize_t str_skip_until_fit( const char* s, ssize_t max_width); // tail that fits +ic_private ssize_t str_take_while_fit( const char* s, ssize_t max_width); // prefix that fits + +#endif // IC_STRINGBUF_H diff --git a/extern/isocline/src/term.c b/extern/isocline/src/term.c new file mode 100644 index 000000000..c55d9ae5d --- /dev/null +++ b/extern/isocline/src/term.c @@ -0,0 +1,1124 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include +#include // getenv +#include + +#include "common.h" +#include "tty.h" +#include "term.h" +#include "stringbuf.h" // str_next_ofs + +#if defined(_WIN32) +#include +#define STDOUT_FILENO 1 +#else +#include +#include +#include +#if defined(__linux__) +#include +#endif +#endif + +#define IC_CSI "\x1B[" + +// color support; colors are auto mapped smaller palettes if needed. (see `term_color.c`) +typedef enum palette_e { + MONOCHROME, // no color + ANSI8, // only basic 8 ANSI color (ESC[m, idx: 30-37, +10 for background) + ANSI16, // basic + bright ANSI colors (ESC[m, idx: 30-37, 90-97, +10 for background) + ANSI256, // ANSI 256 color palette (ESC[38;5;m, idx: 0-15 standard color, 16-231 6x6x6 rbg colors, 232-255 gray shades) + ANSIRGB // direct rgb colors supported (ESC[38;2;;;m) +} palette_t; + +// The terminal screen +struct term_s { + int fd_out; // output handle + ssize_t width; // screen column width + ssize_t height; // screen row height + ssize_t raw_enabled; // is raw mode active? counted by start/end pairs + bool nocolor; // show colors? + bool silent; // enable beep? + bool is_utf8; // utf-8 output? determined by the tty + attr_t attr; // current text attributes + palette_t palette; // color support + buffer_mode_t bufmode; // buffer mode + stringbuf_t* buf; // buffer for buffered output + tty_t* tty; // used on posix to get the cursor position + alloc_t* mem; // allocator + #ifdef _WIN32 + HANDLE hcon; // output console handler + WORD hcon_default_attr; // default text attributes + WORD hcon_orig_attr; // original text attributes + DWORD hcon_orig_mode; // original console mode + DWORD hcon_mode; // used console mode + UINT hcon_orig_cp; // original console code-page (locale) + COORD hcon_save_cursor; // saved cursor position (for escape sequence emulation) + #endif +}; + +static bool term_write_direct(term_t* term, const char* s, ssize_t n ); +static void term_append_buf(term_t* term, const char* s, ssize_t n); + +//------------------------------------------------------------- +// Colors +//------------------------------------------------------------- + +#include "term_color.c" + +//------------------------------------------------------------- +// Helpers +//------------------------------------------------------------- + +ic_private void term_left(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdD", n ); +} + +ic_private void term_right(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdC", n ); +} + +ic_private void term_up(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdA", n ); +} + +ic_private void term_down(term_t* term, ssize_t n) { + if (n <= 0) return; + term_writef( term, IC_CSI "%zdB", n ); +} + +ic_private void term_clear_line(term_t* term) { + term_write( term, "\r" IC_CSI "K"); +} + +ic_private void term_clear_to_end_of_line(term_t* term) { + term_write(term, IC_CSI "K"); +} + +ic_private void term_start_of_line(term_t* term) { + term_write( term, "\r" ); +} + +ic_private ssize_t term_get_width(term_t* term) { + return term->width; +} + +ic_private ssize_t term_get_height(term_t* term) { + return term->height; +} + +ic_private void term_attr_reset(term_t* term) { + term_write(term, IC_CSI "m" ); +} + +ic_private void term_underline(term_t* term, bool on) { + term_write(term, on ? IC_CSI "4m" : IC_CSI "24m" ); +} + +ic_private void term_reverse(term_t* term, bool on) { + term_write(term, on ? IC_CSI "7m" : IC_CSI "27m"); +} + +ic_private void term_bold(term_t* term, bool on) { + term_write(term, on ? IC_CSI "1m" : IC_CSI "22m" ); +} + +ic_private void term_italic(term_t* term, bool on) { + term_write(term, on ? IC_CSI "3m" : IC_CSI "23m" ); +} + +ic_private void term_writeln(term_t* term, const char* s) { + term_write(term,s); + term_write(term,"\n"); +} + +ic_private void term_write_char(term_t* term, char c) { + char buf[2]; + buf[0] = c; + buf[1] = 0; + term_write_n(term, buf, 1 ); +} + +ic_private attr_t term_get_attr( const term_t* term ) { + return term->attr; +} + +ic_private void term_set_attr( term_t* term, attr_t attr ) { + if (term->nocolor) return; + if (attr.x.color != term->attr.x.color && attr.x.color != IC_COLOR_NONE) { + term_color(term,attr.x.color); + if (term->palette < ANSIRGB && color_is_rgb(attr.x.color)) { + term->attr.x.color = attr.x.color; // actual color may have been approximated but we keep the actual color to avoid updating every time + } + } + if (attr.x.bgcolor != term->attr.x.bgcolor && attr.x.bgcolor != IC_COLOR_NONE) { + term_bgcolor(term,attr.x.bgcolor); + if (term->palette < ANSIRGB && color_is_rgb(attr.x.bgcolor)) { + term->attr.x.bgcolor = attr.x.bgcolor; + } + } + if (attr.x.bold != term->attr.x.bold && attr.x.bold != IC_NONE) { + term_bold(term,attr.x.bold == IC_ON); + } + if (attr.x.underline != term->attr.x.underline && attr.x.underline != IC_NONE) { + term_underline(term,attr.x.underline == IC_ON); + } + if (attr.x.reverse != term->attr.x.reverse && attr.x.reverse != IC_NONE) { + term_reverse(term,attr.x.reverse == IC_ON); + } + if (attr.x.italic != term->attr.x.italic && attr.x.italic != IC_NONE) { + term_italic(term,attr.x.italic == IC_ON); + } + assert(attr.x.color == term->attr.x.color || attr.x.color == IC_COLOR_NONE); + assert(attr.x.bgcolor == term->attr.x.bgcolor || attr.x.bgcolor == IC_COLOR_NONE); + assert(attr.x.bold == term->attr.x.bold || attr.x.bold == IC_NONE); + assert(attr.x.reverse == term->attr.x.reverse || attr.x.reverse == IC_NONE); + assert(attr.x.underline == term->attr.x.underline || attr.x.underline == IC_NONE); + assert(attr.x.italic == term->attr.x.italic || attr.x.italic == IC_NONE); +} + + +/* +ic_private void term_clear_lines_to_end(term_t* term) { + term_write(term, "\r" IC_CSI "J"); +} + +ic_private void term_show_cursor(term_t* term, bool on) { + term_write(term, on ? IC_CSI "?25h" : IC_CSI "?25l"); +} +*/ + +//------------------------------------------------------------- +// Formatted output +//------------------------------------------------------------- + +ic_private void term_writef(term_t* term, const char* fmt, ...) { + va_list ap; + va_start(ap, fmt); + term_vwritef(term,fmt,ap); + va_end(ap); +} + +ic_private void term_vwritef(term_t* term, const char* fmt, va_list args ) { + sbuf_append_vprintf(term->buf, fmt, args); +} + +ic_private void term_write_formatted( term_t* term, const char* s, const attr_t* attrs ) { + term_write_formatted_n( term, s, attrs, ic_strlen(s)); +} + +ic_private void term_write_formatted_n( term_t* term, const char* s, const attr_t* attrs, ssize_t len ) { + if (attrs == NULL) { + // write directly + term_write(term,s); + } + else { + // ensure raw mode from now on + if (term->raw_enabled <= 0) { + term_start_raw(term); + } + // and output with text attributes + const attr_t default_attr = term_get_attr(term); + attr_t attr = attr_none(); + ssize_t i = 0; + ssize_t n = 0; + while( i+n < len && s[i+n] != 0 ) { + if (!attr_is_eq(attr,attrs[i+n])) { + if (n > 0) { + term_write_n( term, s+i, n ); + i += n; + n = 0; + } + attr = attrs[i]; + term_set_attr( term, attr_update_with(default_attr,attr) ); + } + n++; + } + if (n > 0) { + term_write_n( term, s+i, n ); + i += n; + n = 0; + } + assert(s[i] != 0 || i == len); + term_set_attr(term, default_attr); + } +} + +//------------------------------------------------------------- +// Write to the terminal +// The buffered functions are used to reduce cursor flicker +// during refresh +//------------------------------------------------------------- + +ic_private void term_beep(term_t* term) { + if (term->silent) return; + fprintf(stderr,"\x7"); + fflush(stderr); +} + +ic_private void term_write_repeat(term_t* term, const char* s, ssize_t count) { + for (; count > 0; count--) { + term_write(term, s); + } +} + +ic_private void term_write(term_t* term, const char* s) { + if (s == NULL || s[0] == 0) return; + ssize_t n = ic_strlen(s); + term_write_n(term,s,n); +} + +// Primitive terminal write; all writes go through here +ic_private void term_write_n(term_t* term, const char* s, ssize_t n) { + if (s == NULL || n <= 0) return; + // write to buffer to reduce flicker and to process escape sequences (this may flush too) + term_append_buf(term, s, n); +} + + +//------------------------------------------------------------- +// Buffering +//------------------------------------------------------------- + + +ic_private void term_flush(term_t* term) { + if (sbuf_len(term->buf) > 0) { + //term_show_cursor(term,false); + term_write_direct(term, sbuf_string(term->buf), sbuf_len(term->buf)); + //term_show_cursor(term,true); + sbuf_clear(term->buf); + } +} + +ic_private buffer_mode_t term_set_buffer_mode(term_t* term, buffer_mode_t mode) { + buffer_mode_t oldmode = term->bufmode; + if (oldmode != mode) { + if (mode == UNBUFFERED) { + term_flush(term); + } + term->bufmode = mode; + } + return oldmode; +} + +static void term_check_flush(term_t* term, bool contains_nl) { + if (term->bufmode == UNBUFFERED || + sbuf_len(term->buf) > 4000 || + (term->bufmode == LINEBUFFERED && contains_nl)) + { + term_flush(term); + } +} + +//------------------------------------------------------------- +// Init +//------------------------------------------------------------- + +static void term_init_raw(term_t* term); + +ic_private term_t* term_new(alloc_t* mem, tty_t* tty, bool nocolor, bool silent, int fd_out ) +{ + term_t* term = mem_zalloc_tp(mem, term_t); + if (term == NULL) return NULL; + + term->fd_out = (fd_out < 0 ? STDOUT_FILENO : fd_out); + term->nocolor = nocolor || (isatty(term->fd_out) == 0); + term->silent = silent; + term->mem = mem; + term->tty = tty; // can be NULL + term->width = 80; + term->height = 25; + term->is_utf8 = tty_is_utf8(tty); + term->palette = ANSI16; // almost universally supported + term->buf = sbuf_new(mem); + term->bufmode = LINEBUFFERED; + term->attr = attr_default(); + + // respect NO_COLOR + if (getenv("NO_COLOR") != NULL) { + term->nocolor = true; + } + if (!term->nocolor) { + // detect color palette + // COLORTERM takes precedence + const char* colorterm = getenv("COLORTERM"); + const char* eterm = getenv("TERM"); + if (ic_contains(colorterm,"24bit") || ic_contains(colorterm,"truecolor") || ic_contains(colorterm,"direct")) { + term->palette = ANSIRGB; + } + else if (ic_contains(colorterm,"8bit") || ic_contains(colorterm,"256color")) { term->palette = ANSI256; } + else if (ic_contains(colorterm,"4bit") || ic_contains(colorterm,"16color")) { term->palette = ANSI16; } + else if (ic_contains(colorterm,"3bit") || ic_contains(colorterm,"8color")) { term->palette = ANSI8; } + else if (ic_contains(colorterm,"1bit") || ic_contains(colorterm,"nocolor") || ic_contains(colorterm,"monochrome")) { + term->palette = MONOCHROME; + } + // otherwise check for some specific terminals + else if (getenv("WT_SESSION") != NULL) { term->palette = ANSIRGB; } // Windows terminal + else if (getenv("ITERM_SESSION_ID") != NULL) { term->palette = ANSIRGB; } // iTerm2 terminal + else if (getenv("VSCODE_PID") != NULL) { term->palette = ANSIRGB; } // vscode terminal + else { + // and otherwise fall back to checking TERM + if (ic_contains(eterm,"truecolor") || ic_contains(eterm,"direct") || ic_contains(colorterm,"24bit")) { + term->palette = ANSIRGB; + } + else if (ic_contains(eterm,"alacritty") || ic_contains(eterm,"kitty")) { + term->palette = ANSIRGB; + } + else if (ic_contains(eterm,"256color") || ic_contains(eterm,"gnome")) { + term->palette = ANSI256; + } + else if (ic_contains(eterm,"16color")){ term->palette = ANSI16; } + else if (ic_contains(eterm,"8color")) { term->palette = ANSI8; } + else if (ic_contains(eterm,"monochrome") || ic_contains(eterm,"nocolor") || ic_contains(eterm,"dumb")) { + term->palette = MONOCHROME; + } + } + debug_msg("term: color-bits: %d (COLORTERM=%s, TERM=%s)\n", term_get_color_bits(term), colorterm, eterm); + } + + // read COLUMS/LINES from the environment for a better initial guess. + const char* env_columns = getenv("COLUMNS"); + if (env_columns != NULL) { ic_atoz(env_columns, &term->width); } + const char* env_lines = getenv("LINES"); + if (env_lines != NULL) { ic_atoz(env_lines, &term->height); } + + // initialize raw terminal output and terminal dimensions + term_init_raw(term); + term_update_dim(term); + term_attr_reset(term); // ensure we are at default settings + + return term; +} + +ic_private bool term_is_interactive(const term_t* term) { + ic_unused(term); + // check dimensions (0 is used for debuggers) + // if (term->width <= 0) return false; + + // check editing support + const char* eterm = getenv("TERM"); + debug_msg("term: TERM=%s\n", eterm); + if (eterm != NULL && + (strstr("dumb|DUMB|cons25|CONS25|emacs|EMACS",eterm) != NULL)) { + return false; + } + + return true; +} + +ic_private bool term_enable_beep(term_t* term, bool enable) { + bool prev = term->silent; + term->silent = !enable; + return prev; +} + +ic_private bool term_enable_color(term_t* term, bool enable) { + bool prev = !term->nocolor; + term->nocolor = !enable; + return prev; +} + +ic_private void term_free(term_t* term) { + if (term == NULL) return; + term_flush(term); + term_end_raw(term, true); + sbuf_free(term->buf); term->buf = NULL; + mem_free(term->mem, term); +} + +//------------------------------------------------------------- +// For best portability and applications inserting CSI SGR (ESC[ .. m) +// codes themselves in strings, we interpret these at the +// lowest level so we can have a `term_get_attr` function which +// is needed for bracketed styles etc. +//------------------------------------------------------------- + +static void term_append_esc(term_t* term, const char* const s, ssize_t len) { + if (s[1]=='[' && s[len-1] == 'm') { + // it is a CSI SGR sequence: ESC[ ... m + if (term->nocolor) return; // ignore escape sequences if nocolor is set + term->attr = attr_update_with(term->attr, attr_from_esc_sgr(s,len)); + } + // and write out the escape sequence as-is + sbuf_append_n(term->buf, s, len); +} + + +static void term_append_utf8(term_t* term, const char* s, ssize_t len) { + ssize_t nread; + unicode_t uchr = unicode_from_qutf8((const uint8_t*)s, len, &nread); + uint8_t c; + if (unicode_is_raw(uchr, &c)) { + // write bytes as is; this also ensure that on non-utf8 terminals characters between 0x80-0xFF + // go through _as is_ due to the qutf8 encoding. + sbuf_append_char(term->buf,(char)c); + } + else if (!term->is_utf8) { + // on non-utf8 terminals still send utf-8 and hope for the best + // todo: we could try to convert to the locale first? + sbuf_append_n(term->buf, s, len); + // sbuf_appendf(term->buf, "\x1B[%" PRIu32 "u", uchr); // unicode escape code + } + else { + // write utf-8 as is + sbuf_append_n(term->buf, s, len); + } +} + +static void term_append_buf( term_t* term, const char* s, ssize_t len ) { + ssize_t pos = 0; + bool newline = false; + while (pos < len) { + // handle ascii sequences in bulk + ssize_t ascii = 0; + ssize_t next; + while ((next = str_next_ofs(s, len, pos+ascii, NULL)) > 0 && + (uint8_t)s[pos + ascii] > '\x1B' && (uint8_t)s[pos + ascii] <= 0x7F ) + { + ascii += next; + } + if (ascii > 0) { + sbuf_append_n(term->buf, s+pos, ascii); + pos += ascii; + } + if (next <= 0) break; + + const uint8_t c = (uint8_t)s[pos]; + // handle utf8 sequences (for non-utf8 terminals) + if (c >= 0x80) { + term_append_utf8(term, s+pos, next); + } + // handle escape sequence (note: str_next_ofs considers whole CSI escape sequences at a time) + else if (next > 1 && c == '\x1B') { + term_append_esc(term, s+pos, next); + } + else if (c < ' ' && c != 0 && (c < '\x07' || c > '\x0D')) { + // ignore control characters except \a, \b, \t, \n, \r, and form-feed and vertical tab. + } + else { + if (c == '\n') { newline = true; } + sbuf_append_n(term->buf, s+pos, next); + } + pos += next; + } + // possibly flush + term_check_flush(term, newline); +} + +//------------------------------------------------------------- +// Platform dependent: Write directly to the terminal +//------------------------------------------------------------- + +#if !defined(_WIN32) + +// write to the console without further processing +static bool term_write_direct(term_t* term, const char* s, ssize_t n) { + ssize_t count = 0; + while( count < n ) { + ssize_t nwritten = write(term->fd_out, s + count, to_size_t(n - count)); + if (nwritten > 0) { + count += nwritten; + } + else if (errno != EINTR && errno != EAGAIN) { + debug_msg("term: write failed: length %i, errno %i: \"%s\"\n", n, errno, s); + return false; + } + } + return true; +} + +#else + +//---------------------------------------------------------------------------------- +// On windows we use the new virtual terminal processing if it is available (Windows Terminal) +// but fall back to ansi escape emulation on older systems but also for example +// the PS terminal +// +// note: we use row/col as 1-based ANSI escape while windows X/Y coords are 0-based. +//----------------------------------------------------------------------------------- + +#if !defined(ENABLE_VIRTUAL_TERMINAL_PROCESSING) +#define ENABLE_VIRTUAL_TERMINAL_PROCESSING (0) +#endif +#if !defined(ENABLE_LVB_GRID_WORLDWIDE) +#define ENABLE_LVB_GRID_WORLDWIDE (0) +#endif + +// direct write to the console without further processing +static bool term_write_console(term_t* term, const char* s, ssize_t n ) { + DWORD written; + // WriteConsoleA(term->hcon, s, (DWORD)(to_size_t(n)), &written, NULL); + WriteFile(term->hcon, s, (DWORD)(to_size_t(n)), &written, NULL); // so it can be redirected + return (written == (DWORD)(to_size_t(n))); +} + +static bool term_get_cursor_pos( term_t* term, ssize_t* row, ssize_t* col) { + *row = 0; + *col = 0; + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(term->hcon, &info)) return false; + *row = (ssize_t)info.dwCursorPosition.Y + 1; + *col = (ssize_t)info.dwCursorPosition.X + 1; + return true; +} + +static void term_move_cursor_to( term_t* term, ssize_t row, ssize_t col ) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + if (col > info.dwSize.X) col = info.dwSize.X; + if (row > info.dwSize.Y) row = info.dwSize.Y; + if (col <= 0) col = 1; + if (row <= 0) row = 1; + COORD coord; + coord.X = (SHORT)col - 1; + coord.Y = (SHORT)row - 1; + SetConsoleCursorPosition( term->hcon, coord); +} + +static void term_cursor_save(term_t* term) { + memset(&term->hcon_save_cursor, 0, sizeof(term->hcon_save_cursor)); + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(term->hcon, &info)) return; + term->hcon_save_cursor = info.dwCursorPosition; +} + +static void term_cursor_restore(term_t* term) { + if (term->hcon_save_cursor.X == 0) return; + SetConsoleCursorPosition(term->hcon, term->hcon_save_cursor); +} + +static void term_move_cursor( term_t* term, ssize_t drow, ssize_t dcol, ssize_t n ) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + COORD cur = info.dwCursorPosition; + ssize_t col = (ssize_t)cur.X + 1 + n*dcol; + ssize_t row = (ssize_t)cur.Y + 1 + n*drow; + term_move_cursor_to( term, row, col ); +} + +static void term_cursor_visible( term_t* term, bool visible ) { + CONSOLE_CURSOR_INFO info; + if (!GetConsoleCursorInfo(term->hcon,&info)) return; + info.bVisible = visible; + SetConsoleCursorInfo(term->hcon,&info); +} + +static void term_erase_line( term_t* term, ssize_t mode ) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + DWORD written; + COORD start; + ssize_t length; + if (mode == 2) { + // entire line + start.X = 0; + start.Y = info.dwCursorPosition.Y; + length = (ssize_t)info.srWindow.Right + 1; + } + else if (mode == 1) { + // to start of line + start.X = 0; + start.Y = info.dwCursorPosition.Y; + length = info.dwCursorPosition.X; + } + else { + // to end of line + length = (ssize_t)info.srWindow.Right - info.dwCursorPosition.X + 1; + start = info.dwCursorPosition; + } + FillConsoleOutputAttribute( term->hcon, term->hcon_default_attr, (DWORD)length, start, &written ); + FillConsoleOutputCharacterA( term->hcon, ' ', (DWORD)length, start, &written ); +} + +static void term_clear_screen(term_t* term, ssize_t mode) { + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo(term->hcon, &info)) return; + COORD start; + start.X = 0; + start.Y = 0; + ssize_t length; + ssize_t width = (ssize_t)info.dwSize.X; + if (mode == 2) { + // entire screen + length = width * info.dwSize.Y; + } + else if (mode == 1) { + // to cursor + length = (width * ((ssize_t)info.dwCursorPosition.Y - 1)) + info.dwCursorPosition.X; + } + else { + // from cursor + start = info.dwCursorPosition; + length = (width * ((ssize_t)info.dwSize.Y - info.dwCursorPosition.Y)) + (width - info.dwCursorPosition.X + 1); + } + DWORD written; + FillConsoleOutputAttribute(term->hcon, term->hcon_default_attr, (DWORD)length, start, &written); + FillConsoleOutputCharacterA(term->hcon, ' ', (DWORD)length, start, &written); +} + +static WORD attr_color[8] = { + 0, // black + FOREGROUND_RED, // maroon + FOREGROUND_GREEN, // green + FOREGROUND_RED | FOREGROUND_GREEN, // orange + FOREGROUND_BLUE, // navy + FOREGROUND_RED | FOREGROUND_BLUE, // purple + FOREGROUND_GREEN | FOREGROUND_BLUE, // teal + FOREGROUND_RED | FOREGROUND_GREEN | FOREGROUND_BLUE, // light gray +}; + +static void term_set_win_attr( term_t* term, attr_t ta ) { + WORD def_attr = term->hcon_default_attr; + CONSOLE_SCREEN_BUFFER_INFO info; + if (!GetConsoleScreenBufferInfo( term->hcon, &info )) return; + WORD cur_attr = info.wAttributes; + WORD attr = cur_attr; + if (ta.x.color != IC_COLOR_NONE) { + if (ta.x.color >= IC_ANSI_BLACK && ta.x.color <= IC_ANSI_SILVER) { + attr = (attr & 0xFFF0) | attr_color[ta.x.color - IC_ANSI_BLACK]; + } + else if (ta.x.color >= IC_ANSI_GRAY && ta.x.color <= IC_ANSI_WHITE) { + attr = (attr & 0xFFF0) | attr_color[ta.x.color - IC_ANSI_GRAY] | FOREGROUND_INTENSITY; + } + else if (ta.x.color == IC_ANSI_DEFAULT) { + attr = (attr & 0xFFF0) | (def_attr & 0x000F); + } + } + if (ta.x.bgcolor != IC_COLOR_NONE) { + if (ta.x.bgcolor >= IC_ANSI_BLACK && ta.x.bgcolor <= IC_ANSI_SILVER) { + attr = (attr & 0xFF0F) | (WORD)(attr_color[ta.x.bgcolor - IC_ANSI_BLACK] << 4); + } + else if (ta.x.bgcolor >= IC_ANSI_GRAY && ta.x.bgcolor <= IC_ANSI_WHITE) { + attr = (attr & 0xFF0F) | (WORD)(attr_color[ta.x.bgcolor - IC_ANSI_GRAY] << 4) | BACKGROUND_INTENSITY; + } + else if (ta.x.bgcolor == IC_ANSI_DEFAULT) { + attr = (attr & 0xFF0F) | (def_attr & 0x00F0); + } + } + if (ta.x.underline != IC_NONE) { + attr = (attr & ~COMMON_LVB_UNDERSCORE) | (ta.x.underline == IC_ON ? COMMON_LVB_UNDERSCORE : 0); + } + if (ta.x.reverse != IC_NONE) { + attr = (attr & ~COMMON_LVB_REVERSE_VIDEO) | (ta.x.reverse == IC_ON ? COMMON_LVB_REVERSE_VIDEO : 0); + } + if (attr != cur_attr) { + SetConsoleTextAttribute(term->hcon, attr); + } +} + +static ssize_t esc_param( const char* s, ssize_t def ) { + if (*s == '?') s++; + ssize_t n = def; + ic_atoz(s, &n); + return n; +} + +static void esc_param2( const char* s, ssize_t* p1, ssize_t* p2, ssize_t def ) { + if (*s == '?') s++; + *p1 = def; + *p2 = def; + ic_atoz2(s, p1, p2); +} + +// Emulate escape sequences on older windows. +static void term_write_esc( term_t* term, const char* s, ssize_t len ) { + ssize_t row; + ssize_t col; + + if (s[1] == '[') { + switch (s[len-1]) { + case 'A': + term_move_cursor(term, -1, 0, esc_param(s+2, 1)); + break; + case 'B': + term_move_cursor(term, 1, 0, esc_param(s+2, 1)); + break; + case 'C': + term_move_cursor(term, 0, 1, esc_param(s+2, 1)); + break; + case 'D': + term_move_cursor(term, 0, -1, esc_param(s+2, 1)); + break; + case 'H': + esc_param2(s+2, &row, &col, 1); + term_move_cursor_to(term, row, col); + break; + case 'K': + term_erase_line(term, esc_param(s+2, 0)); + break; + case 'm': + term_set_win_attr( term, attr_from_esc_sgr(s,len) ); + break; + + // support some less standard escape codes (currently not used by isocline) + case 'E': // line down + term_get_cursor_pos(term, &row, &col); + row += esc_param(s+2, 1); + term_move_cursor_to(term, row, 1); + break; + case 'F': // line up + term_get_cursor_pos(term, &row, &col); + row -= esc_param(s+2, 1); + term_move_cursor_to(term, row, 1); + break; + case 'G': // absolute column + term_get_cursor_pos(term, &row, &col); + col = esc_param(s+2, 1); + term_move_cursor_to(term, row, col); + break; + case 'J': + term_clear_screen(term, esc_param(s+2, 0)); + break; + case 'h': + if (strncmp(s+2, "?25h", 4) == 0) { + term_cursor_visible(term, true); + } + break; + case 'l': + if (strncmp(s+2, "?25l", 4) == 0) { + term_cursor_visible(term, false); + } + break; + case 's': + term_cursor_save(term); + break; + case 'u': + term_cursor_restore(term); + break; + // otherwise ignore + } + } + else if (s[1] == '7') { + term_cursor_save(term); + } + else if (s[1] == '8') { + term_cursor_restore(term); + } + else { + // otherwise ignore + } +} + +static bool term_write_direct(term_t* term, const char* s, ssize_t len ) { + term_cursor_visible(term,false); // reduce flicker + ssize_t pos = 0; + if ((term->hcon_mode & ENABLE_VIRTUAL_TERMINAL_PROCESSING) != 0) { + // use the builtin virtual terminal processing. (enables truecolor for example) + term_write_console(term, s, len); + pos = len; + } + else { + // emulate escape sequences + while( pos < len ) { + // handle non-control in bulk (including utf-8 sequences) + // (We don't need to handle utf-8 separately as we set the codepage to always be in utf-8 mode) + ssize_t nonctrl = 0; + ssize_t next; + while( (next = str_next_ofs( s, len, pos+nonctrl, NULL )) > 0 && + (uint8_t)s[pos + nonctrl] >= ' ' && (uint8_t)s[pos + nonctrl] <= 0x7F) { + nonctrl += next; + } + if (nonctrl > 0) { + term_write_console(term, s+pos, nonctrl); + pos += nonctrl; + } + if (next <= 0) break; + + if ((uint8_t)s[pos] >= 0x80) { + // utf8 is already processed + term_write_console(term, s+pos, next); + } + else if (next > 1 && s[pos] == '\x1B') { + // handle control (note: str_next_ofs considers whole CSI escape sequences at a time) + term_write_esc(term, s+pos, next); + } + else if (next == 1 && (s[pos] == '\r' || s[pos] == '\n' || s[pos] == '\t' || s[pos] == '\b')) { + term_write_console( term, s+pos, next); + } + else { + // ignore + } + pos += next; + } + } + term_cursor_visible(term,true); + assert(pos == len); + return (pos == len); + +} +#endif + + + +//------------------------------------------------------------- +// Update terminal dimensions +//------------------------------------------------------------- + +#if !defined(_WIN32) + +// send escape query that may return a response on the tty +static bool term_esc_query_raw( term_t* term, const char* query, char* buf, ssize_t buflen ) +{ + if (buf==NULL || buflen <= 0 || query[0] == 0) return false; + bool osc = (query[1] == ']'); + if (!term_write_direct(term, query, ic_strlen(query))) return false; + debug_msg("term: read tty query response to: ESC %s\n", query + 1); + return tty_read_esc_response( term->tty, query[1], osc, buf, buflen ); +} + +static bool term_esc_query( term_t* term, const char* query, char* buf, ssize_t buflen ) +{ + if (!tty_start_raw(term->tty)) return false; + bool ok = term_esc_query_raw(term,query,buf,buflen); + tty_end_raw(term->tty); + return ok; +} + +// get the cursor position via an ESC[6n +static bool term_get_cursor_pos( term_t* term, ssize_t* row, ssize_t* col) +{ + // send escape query + char buf[128]; + if (!term_esc_query(term,"\x1B[6n",buf,128)) return false; + if (!ic_atoz2(buf,row,col)) return false; + return true; +} + +static void term_set_cursor_pos( term_t* term, ssize_t row, ssize_t col ) { + term_writef( term, IC_CSI "%zd;%zdH", row, col ); +} + +ic_private bool term_update_dim(term_t* term) { + ssize_t cols = 0; + ssize_t rows = 0; + struct winsize ws; + if (ioctl(term->fd_out, TIOCGWINSZ, &ws) >= 0) { + // ioctl succeeded + cols = ws.ws_col; // debuggers return 0 for the column + rows = ws.ws_row; + } + else { + // determine width by querying the cursor position + debug_msg("term: ioctl term-size failed: %d,%d\n", ws.ws_row, ws.ws_col); + ssize_t col0 = 0; + ssize_t row0 = 0; + if (term_get_cursor_pos(term,&row0,&col0)) { + term_set_cursor_pos(term,999,999); + ssize_t col1 = 0; + ssize_t row1 = 0; + if (term_get_cursor_pos(term,&row1,&col1)) { + cols = col1; + rows = row1; + } + term_set_cursor_pos(term,row0,col0); + } + else { + // cannot query position + // return 0 column + } + } + + // update width and return whether it changed. + bool changed = (term->width != cols || term->height != rows); + debug_msg("terminal dim: %zd,%zd: %s\n", rows, cols, changed ? "changed" : "unchanged"); + if (cols > 0) { + term->width = cols; + term->height = rows; + } + return changed; +} + +#else + +ic_private bool term_update_dim(term_t* term) { + if (term->hcon == 0) { + term->hcon = GetConsoleWindow(); + } + ssize_t rows = 0; + ssize_t cols = 0; + CONSOLE_SCREEN_BUFFER_INFO sbinfo; + if (GetConsoleScreenBufferInfo(term->hcon, &sbinfo)) { + cols = (ssize_t)sbinfo.srWindow.Right - (ssize_t)sbinfo.srWindow.Left + 1; + rows = (ssize_t)sbinfo.srWindow.Bottom - (ssize_t)sbinfo.srWindow.Top + 1; + } + bool changed = (term->width != cols || term->height != rows); + term->width = cols; + term->height = rows; + debug_msg("term: update dim: %zd, %zd\n", term->height, term->width ); + return changed; +} + +#endif + + + +//------------------------------------------------------------- +// Enable/disable terminal raw mode +//------------------------------------------------------------- + +#if !defined(_WIN32) + +// On non-windows, the terminal is set in raw mode by the tty. + +ic_private void term_start_raw(term_t* term) { + term->raw_enabled++; +} + +ic_private void term_end_raw(term_t* term, bool force) { + if (term->raw_enabled <= 0) return; + if (!force) { + term->raw_enabled--; + } + else { + term->raw_enabled = 0; + } +} + +static bool term_esc_query_color_raw(term_t* term, int color_idx, uint32_t* color ) { + char buf[128+1]; + snprintf(buf,128,"\x1B]4;%d;?\x1B\\", color_idx); + if (!term_esc_query_raw( term, buf, buf, 128 )) { + debug_msg("esc query response not received\n"); + return false; + } + if (buf[0] != '4') return false; + const char* rgb = strchr(buf,':'); + if (rgb==NULL) return false; + rgb++; // skip ':' + unsigned int r,g,b; + if (sscanf(rgb,"%x/%x/%x",&r,&g,&b) != 3) return false; + if (rgb[2]!='/') { // 48-bit rgb, hexadecimal round to 24-bit + r = (r+0x7F)/0x100; // note: can "overflow", e.g. 0xFFFF -> 0x100. (and we need `ic_cap8` to convert.) + g = (g+0x7F)/0x100; + b = (b+0x7F)/0x100; + } + *color = (ic_cap8(r)<<16) | (ic_cap8(g)<<8) | ic_cap8(b); + debug_msg("color query: %02x,%02x,%02x: %06x\n", r, g, b, *color); + return true; +} + +// update ansi 16 color palette for better color approximation +static void term_update_ansi16(term_t* term) { + debug_msg("update ansi colors\n"); + #if defined(GIO_CMAP) + // try ioctl first (on Linux) + uint8_t cmap[48]; + memset(cmap,0,48); + if (ioctl(term->fd_out,GIO_CMAP,&cmap) >= 0) { + // success + for(ssize_t i = 0; i < 48; i+=3) { + uint32_t color = ((uint32_t)(cmap[i]) << 16) | ((uint32_t)(cmap[i+1]) << 8) | cmap[i+2]; + debug_msg("term (ioctl) ansi color %d: 0x%06x\n", i, color); + ansi256[i] = color; + } + return; + } + else { + debug_msg("ioctl GIO_CMAP failed: entry 1: 0x%02x%02x%02x\n", cmap[3], cmap[4], cmap[5]); + } + #endif + // this seems to be unreliable on some systems (Ubuntu+Gnome terminal) so only enable when known ok. + #if __APPLE__ + // otherwise use OSC 4 escape sequence query + if (tty_start_raw(term->tty)) { + for(ssize_t i = 0; i < 16; i++) { + uint32_t color; + if (!term_esc_query_color_raw(term, i, &color)) break; + debug_msg("term ansi color %d: 0x%06x\n", i, color); + ansi256[i] = color; + } + tty_end_raw(term->tty); + } + #endif +} + +static void term_init_raw(term_t* term) { + if (term->palette < ANSIRGB) { + term_update_ansi16(term); + } +} + +#else + +ic_private void term_start_raw(term_t* term) { + if (term->raw_enabled++ > 0) return; + CONSOLE_SCREEN_BUFFER_INFO info; + if (GetConsoleScreenBufferInfo(term->hcon, &info)) { + term->hcon_orig_attr = info.wAttributes; + } + term->hcon_orig_cp = GetConsoleOutputCP(); + SetConsoleOutputCP(CP_UTF8); + if (term->hcon_mode == 0) { + // first time initialization + DWORD mode = ENABLE_PROCESSED_OUTPUT | ENABLE_WRAP_AT_EOL_OUTPUT | ENABLE_LVB_GRID_WORLDWIDE; // for \r \n and \b + // use escape sequence handling if available and the terminal supports it (so we can use rgb colors in Windows terminal) + // Unfortunately, in plain powershell, we can successfully enable terminal processing + // but it still fails to render correctly; so we require the palette be large enough (like in Windows Terminal) + if (term->palette >= ANSI256 && SetConsoleMode(term->hcon, mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING)) { + term->hcon_mode = mode | ENABLE_VIRTUAL_TERMINAL_PROCESSING; + debug_msg("term: console mode: virtual terminal processing enabled\n"); + } + // no virtual terminal processing, emulate instead + else if (SetConsoleMode(term->hcon, mode)) { + term->hcon_mode = mode; + term->palette = ANSI16; + } + GetConsoleMode(term->hcon, &mode); + debug_msg("term: console mode: orig: 0x%x, new: 0x%x, current 0x%x\n", term->hcon_orig_mode, term->hcon_mode, mode); + } + else { + SetConsoleMode(term->hcon, term->hcon_mode); + } +} + +ic_private void term_end_raw(term_t* term, bool force) { + if (term->raw_enabled <= 0) return; + if (!force && term->raw_enabled > 1) { + term->raw_enabled--; + } + else { + term->raw_enabled = 0; + SetConsoleMode(term->hcon, term->hcon_orig_mode); + SetConsoleOutputCP(term->hcon_orig_cp); + SetConsoleTextAttribute(term->hcon, term->hcon_orig_attr); + } +} + +static void term_init_raw(term_t* term) { + term->hcon = GetStdHandle(STD_OUTPUT_HANDLE); + GetConsoleMode(term->hcon, &term->hcon_orig_mode); + CONSOLE_SCREEN_BUFFER_INFOEX info; + memset(&info, 0, sizeof(info)); + info.cbSize = sizeof(info); + if (GetConsoleScreenBufferInfoEx(term->hcon, &info)) { + // store default attributes + term->hcon_default_attr = info.wAttributes; + // update our color table with the actual colors used. + for (unsigned i = 0; i < 16; i++) { + COLORREF cr = info.ColorTable[i]; + uint32_t color = (ic_cap8(GetRValue(cr))<<16) | (ic_cap8(GetGValue(cr))<<8) | ic_cap8(GetBValue(cr)); // COLORREF = BGR + // index is also in reverse in the bits 0 and 2 + unsigned j = (i&0x08) | ((i&0x04)>>2) | (i&0x02) | (i&0x01)<<2; + debug_msg("term: ansi color %d is 0x%06x\n", j, color); + ansi256[j] = color; + } + } + else { + DWORD err = GetLastError(); + debug_msg("term: cannot get console screen buffer: %d %x", err, err); + } + term_start_raw(term); // initialize the hcon_mode + term_end_raw(term,false); +} + +#endif diff --git a/extern/isocline/src/term.h b/extern/isocline/src/term.h new file mode 100644 index 000000000..50bfd9682 --- /dev/null +++ b/extern/isocline/src/term.h @@ -0,0 +1,85 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_TERM_H +#define IC_TERM_H + +#include "common.h" +#include "tty.h" +#include "stringbuf.h" +#include "attr.h" + +struct term_s; +typedef struct term_s term_t; + +typedef enum buffer_mode_e { + UNBUFFERED, + LINEBUFFERED, + BUFFERED, +} buffer_mode_t; + +// Primitives +ic_private term_t* term_new(alloc_t* mem, tty_t* tty, bool nocolor, bool silent, int fd_out); +ic_private void term_free(term_t* term); + +ic_private bool term_is_interactive(const term_t* term); +ic_private void term_start_raw(term_t* term); +ic_private void term_end_raw(term_t* term, bool force); + +ic_private bool term_enable_beep(term_t* term, bool enable); +ic_private bool term_enable_color(term_t* term, bool enable); + +ic_private void term_flush(term_t* term); +ic_private buffer_mode_t term_set_buffer_mode(term_t* term, buffer_mode_t mode); + +ic_private void term_write_n(term_t* term, const char* s, ssize_t n); +ic_private void term_write(term_t* term, const char* s); +ic_private void term_writeln(term_t* term, const char* s); +ic_private void term_write_char(term_t* term, char c); + +ic_private void term_write_repeat(term_t* term, const char* s, ssize_t count ); +ic_private void term_beep(term_t* term); + +ic_private bool term_update_dim(term_t* term); + +ic_private ssize_t term_get_width(term_t* term); +ic_private ssize_t term_get_height(term_t* term); +ic_private int term_get_color_bits(term_t* term); + +// Helpers +ic_private void term_writef(term_t* term, const char* fmt, ...); +ic_private void term_vwritef(term_t* term, const char* fmt, va_list args); + +ic_private void term_left(term_t* term, ssize_t n); +ic_private void term_right(term_t* term, ssize_t n); +ic_private void term_up(term_t* term, ssize_t n); +ic_private void term_down(term_t* term, ssize_t n); +ic_private void term_start_of_line(term_t* term ); +ic_private void term_clear_line(term_t* term); +ic_private void term_clear_to_end_of_line(term_t* term); +// ic_private void term_clear_lines_to_end(term_t* term); + + +ic_private void term_attr_reset(term_t* term); +ic_private void term_underline(term_t* term, bool on); +ic_private void term_reverse(term_t* term, bool on); +ic_private void term_bold(term_t* term, bool on); +ic_private void term_italic(term_t* term, bool on); + +ic_private void term_color(term_t* term, ic_color_t color); +ic_private void term_bgcolor(term_t* term, ic_color_t color); + +// Formatted output + +ic_private attr_t term_get_attr( const term_t* term ); +ic_private void term_set_attr( term_t* term, attr_t attr ); +ic_private void term_write_formatted( term_t* term, const char* s, const attr_t* attrs ); +ic_private void term_write_formatted_n( term_t* term, const char* s, const attr_t* attrs, ssize_t n ); + +ic_private ic_color_t color_from_ansi256(ssize_t i); + +#endif // IC_TERM_H diff --git a/extern/isocline/src/term_color.c b/extern/isocline/src/term_color.c new file mode 100644 index 000000000..98af3cf4f --- /dev/null +++ b/extern/isocline/src/term_color.c @@ -0,0 +1,371 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ + +// This file is included in "term.c" + +//------------------------------------------------------------- +// Standard ANSI palette for 256 colors +//------------------------------------------------------------- + +static uint32_t ansi256[256] = { + // not const as on some platforms (e.g. Windows, xterm) we update the first 16 entries with the actual used colors. + // 0, standard ANSI + 0x000000, 0x800000, 0x008000, 0x808000, 0x000080, 0x800080, + 0x008080, 0xc0c0c0, + // 8, bright ANSI + 0x808080, 0xff0000, 0x00ff00, 0xffff00, 0x0000ff, 0xff00ff, + 0x00ffff, 0xffffff, + // 6x6x6 RGB colors + // 16 + 0x000000, 0x00005f, 0x000087, 0x0000af, 0x0000d7, 0x0000ff, + 0x005f00, 0x005f5f, 0x005f87, 0x005faf, 0x005fd7, 0x005fff, + 0x008700, 0x00875f, 0x008787, 0x0087af, 0x0087d7, 0x0087ff, + 0x00af00, 0x00af5f, 0x00af87, 0x00afaf, 0x00afd7, 0x00afff, + 0x00d700, 0x00d75f, 0x00d787, 0x00d7af, 0x00d7d7, 0x00d7ff, + 0x00ff00, 0x00ff5f, 0x00ff87, 0x00ffaf, 0x00ffd7, 0x00ffff, + // 52 + 0x5f0000, 0x5f005f, 0x5f0087, 0x5f00af, 0x5f00d7, 0x5f00ff, + 0x5f5f00, 0x5f5f5f, 0x5f5f87, 0x5f5faf, 0x5f5fd7, 0x5f5fff, + 0x5f8700, 0x5f875f, 0x5f8787, 0x5f87af, 0x5f87d7, 0x5f87ff, + 0x5faf00, 0x5faf5f, 0x5faf87, 0x5fafaf, 0x5fafd7, 0x5fafff, + 0x5fd700, 0x5fd75f, 0x5fd787, 0x5fd7af, 0x5fd7d7, 0x5fd7ff, + 0x5fff00, 0x5fff5f, 0x5fff87, 0x5fffaf, 0x5fffd7, 0x5fffff, + // 88 + 0x870000, 0x87005f, 0x870087, 0x8700af, 0x8700d7, 0x8700ff, + 0x875f00, 0x875f5f, 0x875f87, 0x875faf, 0x875fd7, 0x875fff, + 0x878700, 0x87875f, 0x878787, 0x8787af, 0x8787d7, 0x8787ff, + 0x87af00, 0x87af5f, 0x87af87, 0x87afaf, 0x87afd7, 0x87afff, + 0x87d700, 0x87d75f, 0x87d787, 0x87d7af, 0x87d7d7, 0x87d7ff, + 0x87ff00, 0x87ff5f, 0x87ff87, 0x87ffaf, 0x87ffd7, 0x87ffff, + // 124 + 0xaf0000, 0xaf005f, 0xaf0087, 0xaf00af, 0xaf00d7, 0xaf00ff, + 0xaf5f00, 0xaf5f5f, 0xaf5f87, 0xaf5faf, 0xaf5fd7, 0xaf5fff, + 0xaf8700, 0xaf875f, 0xaf8787, 0xaf87af, 0xaf87d7, 0xaf87ff, + 0xafaf00, 0xafaf5f, 0xafaf87, 0xafafaf, 0xafafd7, 0xafafff, + 0xafd700, 0xafd75f, 0xafd787, 0xafd7af, 0xafd7d7, 0xafd7ff, + 0xafff00, 0xafff5f, 0xafff87, 0xafffaf, 0xafffd7, 0xafffff, + // 160 + 0xd70000, 0xd7005f, 0xd70087, 0xd700af, 0xd700d7, 0xd700ff, + 0xd75f00, 0xd75f5f, 0xd75f87, 0xd75faf, 0xd75fd7, 0xd75fff, + 0xd78700, 0xd7875f, 0xd78787, 0xd787af, 0xd787d7, 0xd787ff, + 0xd7af00, 0xd7af5f, 0xd7af87, 0xd7afaf, 0xd7afd7, 0xd7afff, + 0xd7d700, 0xd7d75f, 0xd7d787, 0xd7d7af, 0xd7d7d7, 0xd7d7ff, + 0xd7ff00, 0xd7ff5f, 0xd7ff87, 0xd7ffaf, 0xd7ffd7, 0xd7ffff, + // 196 + 0xff0000, 0xff005f, 0xff0087, 0xff00af, 0xff00d7, 0xff00ff, + 0xff5f00, 0xff5f5f, 0xff5f87, 0xff5faf, 0xff5fd7, 0xff5fff, + 0xff8700, 0xff875f, 0xff8787, 0xff87af, 0xff87d7, 0xff87ff, + 0xffaf00, 0xffaf5f, 0xffaf87, 0xffafaf, 0xffafd7, 0xffafff, + 0xffd700, 0xffd75f, 0xffd787, 0xffd7af, 0xffd7d7, 0xffd7ff, + 0xffff00, 0xffff5f, 0xffff87, 0xffffaf, 0xffffd7, 0xffffff, + // 232, gray scale + 0x080808, 0x121212, 0x1c1c1c, 0x262626, 0x303030, 0x3a3a3a, + 0x444444, 0x4e4e4e, 0x585858, 0x626262, 0x6c6c6c, 0x767676, + 0x808080, 0x8a8a8a, 0x949494, 0x9e9e9e, 0xa8a8a8, 0xb2b2b2, + 0xbcbcbc, 0xc6c6c6, 0xd0d0d0, 0xdadada, 0xe4e4e4, 0xeeeeee +}; + + +//------------------------------------------------------------- +// Create colors +//------------------------------------------------------------- + +// Create a color from a 24-bit color value. +ic_private ic_color_t ic_rgb(uint32_t hex) { + return (ic_color_t)(0x1000000 | (hex & 0xFFFFFF)); +} + +// Limit an int to values between 0 and 255. +static uint32_t ic_cap8(ssize_t i) { + return (i < 0 ? 0 : (i > 255 ? 255 : (uint32_t)i)); +} + +// Create a color from a 24-bit color value. +ic_private ic_color_t ic_rgbx(ssize_t r, ssize_t g, ssize_t b) { + return ic_rgb( (ic_cap8(r)<<16) | (ic_cap8(g)<<8) | ic_cap8(b) ); +} + + +//------------------------------------------------------------- +// Match an rgb color to a ansi8, ansi16, or ansi256 +//------------------------------------------------------------- + +static bool color_is_rgb( ic_color_t color ) { + return (color >= IC_RGB(0)); // bit 24 is set for rgb colors +} + +static void color_to_rgb(ic_color_t color, int* r, int* g, int* b) { + assert(color_is_rgb(color)); + *r = ((color >> 16) & 0xFF); + *g = ((color >> 8) & 0xFF); + *b = (color & 0xFF); +} + +ic_private ic_color_t color_from_ansi256(ssize_t i) { + if (i >= 0 && i < 8) { + return (IC_ANSI_BLACK + (uint32_t)i); + } + else if (i >= 8 && i < 16) { + return (IC_ANSI_DARKGRAY + (uint32_t)(i - 8)); + } + else if (i >= 16 && i <= 255) { + return ic_rgb( ansi256[i] ); + } + else if (i == 256) { + return IC_ANSI_DEFAULT; + } + else { + return IC_ANSI_DEFAULT; + } +} + +static bool is_grayish(int r, int g, int b) { + return (abs(r-g) <= 4) && (abs((r+g)/2 - b) <= 4); +} + +static bool is_grayish_color( uint32_t rgb ) { + int r, g, b; + color_to_rgb(IC_RGB(rgb),&r,&g,&b); + return is_grayish(r,g,b); +} + +static int_least32_t sqr(int_least32_t x) { + return x*x; +} + +// Approximation to delta-E CIE color distance using much +// simpler calculations. See . +// This is essentialy weighted euclidean distance but the weight distribution +// depends on how big the "red" component of the color is. +// We do not take the square root as we only need to find +// the minimal distance (and multiply by 256 to increase precision). +// Needs at least 28-bit signed integers to avoid overflow. +static int_least32_t rgb_distance_rmean( uint32_t color, int r2, int g2, int b2 ) { + int r1, g1, b1; + color_to_rgb(IC_RGB(color),&r1,&g1,&b1); + int_least32_t rmean = (r1 + r2) / 2; + int_least32_t dr2 = sqr(r1 - r2); + int_least32_t dg2 = sqr(g1 - g2); + int_least32_t db2 = sqr(b1 - b2); + int_least32_t dist = ((512+rmean)*dr2) + 1024*dg2 + ((767-rmean)*db2); + return dist; +} + +// Another approximation to delta-E CIE color distance using +// simpler calculations. Similar to `rmean` but adds an adjustment factor +// based on the "red/blue" difference. +static int_least32_t rgb_distance_rbmean( uint32_t color, int r2, int g2, int b2 ) { + int r1, g1, b1; + color_to_rgb(IC_RGB(color),&r1,&g1,&b1); + int_least32_t rmean = (r1 + r2) / 2; + int_least32_t dr2 = sqr(r1 - r2); + int_least32_t dg2 = sqr(g1 - g2); + int_least32_t db2 = sqr(b1 - b2); + int_least32_t dist = 2*dr2 + 4*dg2 + 3*db2 + ((rmean*(dr2 - db2))/256); + return dist; +} + + +// Maintain a small cache of recently used colors. Should be short enough to be effectively constant time. +// If we ever use a more expensive color distance method, we may increase the size a bit (64?) +// (Initial zero initialized cache is valid.) +#define RGB_CACHE_LEN (16) +typedef struct rgb_cache_s { + int last; + int indices[RGB_CACHE_LEN]; + ic_color_t colors[RGB_CACHE_LEN]; +} rgb_cache_t; + +// remember a color in the LRU cache +void rgb_remember( rgb_cache_t* cache, ic_color_t color, int idx ) { + if (cache == NULL) return; + cache->colors[cache->last] = color; + cache->indices[cache->last] = idx; + cache->last++; + if (cache->last >= RGB_CACHE_LEN) { cache->last = 0; } +} + +// quick lookup in cache; -1 on failure +int rgb_lookup( const rgb_cache_t* cache, ic_color_t color ) { + if (cache != NULL) { + for(int i = 0; i < RGB_CACHE_LEN; i++) { + if (cache->colors[i] == color) return cache->indices[i]; + } + } + return -1; +} + +// return the index of the closest matching color +static int rgb_match( uint32_t* palette, int start, int len, rgb_cache_t* cache, ic_color_t color ) { + assert(color_is_rgb(color)); + // in cache? + int min = rgb_lookup(cache,color); + if (min >= 0) { + return min; + } + // otherwise find closest color match in the palette + int r, g, b; + color_to_rgb(color,&r,&g,&b); + min = start; + int_least32_t mindist = (INT_LEAST32_MAX)/4; + for(int i = start; i < len; i++) { + //int_least32_t dist = rgb_distance_rbmean(palette[i],r,g,b); + int_least32_t dist = rgb_distance_rmean(palette[i],r,g,b); + if (is_grayish_color(palette[i]) != is_grayish(r, g, b)) { + // with few colors, make it less eager to substitute a gray for a non-gray (or the other way around) + if (len <= 16) { + dist *= 4; + } + else { + dist = (dist/4)*5; + } + } + if (dist < mindist) { + min = i; + mindist = dist; + } + } + rgb_remember(cache,color,min); + return min; +} + + +// Match RGB to an index in the ANSI 256 color table +static int rgb_to_ansi256(ic_color_t color) { + static rgb_cache_t ansi256_cache; + int c = rgb_match(ansi256, 16, 256, &ansi256_cache, color); // not the first 16 ANSI colors as those may be different + //debug_msg("term: rgb %x -> ansi 256: %d\n", color, c ); + return c; +} + +// Match RGB to an ANSI 16 color code (30-37, 90-97) +static int color_to_ansi16(ic_color_t color) { + if (!color_is_rgb(color)) { + return (int)color; + } + else { + static rgb_cache_t ansi16_cache; + int c = rgb_match(ansi256, 0, 16, &ansi16_cache, color); + //debug_msg("term: rgb %x -> ansi 16: %d\n", color, c ); + return (c < 8 ? 30 + c : 90 + c - 8); + } +} + +// Match RGB to an ANSI 16 color code (30-37, 90-97) +// but assuming the bright colors are simulated using 'bold'. +static int color_to_ansi8(ic_color_t color) { + if (!color_is_rgb(color)) { + return (int)color; + } + else { + // match to basic 8 colors first + static rgb_cache_t ansi8_cache; + int c = 30 + rgb_match(ansi256, 0, 8, &ansi8_cache, color); + // and then adjust for brightness + int r, g, b; + color_to_rgb(color,&r,&g,&b); + if (r>=196 || g>=196 || b>=196) c += 60; + //debug_msg("term: rgb %x -> ansi 8: %d\n", color, c ); + return c; + } +} + + +//------------------------------------------------------------- +// Emit color escape codes based on the terminal capability +//------------------------------------------------------------- + +static void fmt_color_ansi8( char* buf, ssize_t len, ic_color_t color, bool bg ) { + int c = color_to_ansi8(color) + (bg ? 10 : 0); + if (c >= 90) { + snprintf(buf, to_size_t(len), IC_CSI "1;%dm", c - 60); + } + else { + snprintf(buf, to_size_t(len), IC_CSI "22;%dm", c ); + } +} + +static void fmt_color_ansi16( char* buf, ssize_t len, ic_color_t color, bool bg ) { + snprintf( buf, to_size_t(len), IC_CSI "%dm", color_to_ansi16(color) + (bg ? 10 : 0) ); +} + +static void fmt_color_ansi256( char* buf, ssize_t len, ic_color_t color, bool bg ) { + if (!color_is_rgb(color)) { + fmt_color_ansi16(buf,len,color,bg); + } + else { + snprintf( buf, to_size_t(len), IC_CSI "%d;5;%dm", (bg ? 48 : 38), rgb_to_ansi256(color) ); + } +} + +static void fmt_color_rgb( char* buf, ssize_t len, ic_color_t color, bool bg ) { + if (!color_is_rgb(color)) { + fmt_color_ansi16(buf,len,color,bg); + } + else { + int r,g,b; + color_to_rgb(color, &r,&g,&b); + snprintf( buf, to_size_t(len), IC_CSI "%d;2;%d;%d;%dm", (bg ? 48 : 38), r, g, b ); + } +} + +static void fmt_color_ex(char* buf, ssize_t len, palette_t palette, ic_color_t color, bool bg) { + if (color == IC_COLOR_NONE || palette == MONOCHROME) return; + if (palette == ANSI8) { + fmt_color_ansi8(buf,len,color,bg); + } + else if (!color_is_rgb(color) || palette == ANSI16) { + fmt_color_ansi16(buf,len,color,bg); + } + else if (palette == ANSI256) { + fmt_color_ansi256(buf,len,color,bg); + } + else { + fmt_color_rgb(buf,len,color,bg); + } +} + +static void term_color_ex(term_t* term, ic_color_t color, bool bg) { + char buf[128+1]; + fmt_color_ex(buf,128,term->palette,color,bg); + term_write(term,buf); +} + +//------------------------------------------------------------- +// Main API functions +//------------------------------------------------------------- + +ic_private void term_color(term_t* term, ic_color_t color) { + term_color_ex(term,color,false); +} + +ic_private void term_bgcolor(term_t* term, ic_color_t color) { + term_color_ex(term,color,true); +} + +ic_private void term_append_color(term_t* term, stringbuf_t* sbuf, ic_color_t color) { + char buf[128+1]; + fmt_color_ex(buf,128,term->palette,color,false); + sbuf_append(sbuf,buf); +} + +ic_private void term_append_bgcolor(term_t* term, stringbuf_t* sbuf, ic_color_t color) { + char buf[128+1]; + fmt_color_ex(buf, 128, term->palette, color, true); + sbuf_append(sbuf, buf); +} + +ic_private int term_get_color_bits(term_t* term) { + switch (term->palette) { + case MONOCHROME: return 1; + case ANSI8: return 3; + case ANSI16: return 4; + case ANSI256: return 8; + case ANSIRGB: return 24; + default: return 4; + } +} diff --git a/extern/isocline/src/tty.c b/extern/isocline/src/tty.c new file mode 100644 index 000000000..09f7aedd3 --- /dev/null +++ b/extern/isocline/src/tty.c @@ -0,0 +1,889 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include +#include +#include +#include + +#include "tty.h" + +#if defined(_WIN32) +#include +#include +#define isatty(fd) _isatty(fd) +#define read(fd,s,n) _read(fd,s,n) +#define STDIN_FILENO 0 +#if (_WIN32_WINNT < 0x0600) +WINBASEAPI ULONGLONG WINAPI GetTickCount64(VOID); +#endif +#else +#include +#include +#include +#include +#include +#include +#if !defined(FIONREAD) +#include +#endif +#endif + +#define TTY_PUSH_MAX (32) + +struct tty_s { + int fd_in; // input handle + bool raw_enabled; // is raw mode enabled? + bool is_utf8; // is the input stream in utf-8 mode? + bool has_term_resize_event; // are resize events generated? + bool term_resize_event; // did a term resize happen? + alloc_t* mem; // memory allocator + code_t pushbuf[TTY_PUSH_MAX]; // push back buffer for full key codes + ssize_t push_count; + uint8_t cpushbuf[TTY_PUSH_MAX]; // low level push back buffer for bytes + ssize_t cpush_count; + long esc_initial_timeout; // initial ms wait to see if ESC starts an escape sequence + long esc_timeout; // follow up delay for characters in an escape sequence + #if defined(_WIN32) + HANDLE hcon; // console input handle + DWORD hcon_orig_mode; // original console mode + #else + struct termios orig_ios; // original terminal settings + struct termios raw_ios; // raw terminal settings + #endif +}; + + +//------------------------------------------------------------- +// Forward declarations of platform dependent primitives below +//------------------------------------------------------------- + +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms); // does not modify `c` when no input (false is returned) + +//------------------------------------------------------------- +// Key code helpers +//------------------------------------------------------------- + +ic_private bool code_is_ascii_char(code_t c, char* chr ) { + if (c >= ' ' && c <= 0x7F) { + if (chr != NULL) *chr = (char)c; + return true; + } + else { + if (chr != NULL) *chr = 0; + return false; + } +} + +ic_private bool code_is_unicode(code_t c, unicode_t* uchr) { + if (c <= KEY_UNICODE_MAX) { + if (uchr != NULL) *uchr = c; + return true; + } + else { + if (uchr != NULL) *uchr = 0; + return false; + } +} + +ic_private bool code_is_virt_key(code_t c ) { + return (KEY_NO_MODS(c) <= 0x20 || KEY_NO_MODS(c) >= KEY_VIRT); +} + + +//------------------------------------------------------------- +// Read a key code +//------------------------------------------------------------- +static code_t modify_code( code_t code ); + +static code_t tty_read_utf8( tty_t* tty, uint8_t c0 ) { + uint8_t buf[5]; + memset(buf, 0, 5); + + // try to read as many bytes as potentially needed + buf[0] = c0; + ssize_t count = 1; + if (c0 > 0x7F) { + if (tty_readc_noblock(tty, buf+count, tty->esc_timeout)) { + count++; + if (c0 > 0xDF) { + if (tty_readc_noblock(tty, buf+count, tty->esc_timeout)) { + count++; + if (c0 > 0xEF) { + if (tty_readc_noblock(tty, buf+count, tty->esc_timeout)) { + count++; + } + } + } + } + } + } + + buf[count] = 0; + debug_msg("tty: read utf8: count: %zd: %02x,%02x,%02x,%02x", count, buf[0], buf[1], buf[2], buf[3]); + + // decode the utf8 to unicode + ssize_t read = 0; + code_t code = key_unicode(unicode_from_qutf8(buf, count, &read)); + + // push back unused bytes (in the case of invalid utf8) + while (count > read) { + count--; + if (count >= 0 && count <= 4) { // to help the static analyzer + tty_cpush_char(tty, buf[count]); + } + } + return code; +} + +// pop a code from the pushback buffer. +static bool tty_code_pop(tty_t* tty, code_t* code); + + +// read a single char/key +ic_private bool tty_read_timeout(tty_t* tty, long timeout_ms, code_t* code) +{ + // is there a push_count back code? + if (tty_code_pop(tty,code)) { + return code; + } + + // read a single char/byte from a character stream + uint8_t c; + if (!tty_readc_noblock(tty, &c, timeout_ms)) return false; + + if (c == KEY_ESC) { + // escape sequence? + *code = tty_read_esc(tty, tty->esc_initial_timeout, tty->esc_timeout); + } + else if (c <= 0x7F) { + // ascii + *code = key_unicode(c); + } + else if (tty->is_utf8) { + // utf8 sequence + *code = tty_read_utf8(tty,c); + } + else { + // c >= 0x80 but tty is not utf8; use raw plane so we can translate it back in the end + *code = key_unicode( unicode_from_raw(c) ); + } + + *code = modify_code(*code); + return true; +} + +// Transform virtual keys to be more portable across platforms +static code_t modify_code( code_t code ) { + code_t key = KEY_NO_MODS(code); + code_t mods = KEY_MODS(code); + debug_msg( "tty: readc %s%s%s 0x%03x ('%c')\n", + mods&KEY_MOD_SHIFT ? "shift+" : "", mods&KEY_MOD_CTRL ? "ctrl+" : "", mods&KEY_MOD_ALT ? "alt+" : "", + key, (key >= ' ' && key <= '~' ? key : ' ')); + + // treat KEY_RUBOUT (0x7F) as KEY_BACKSP + if (key == KEY_RUBOUT) { + code = KEY_BACKSP | mods; + } + // ctrl+'_' is translated to '\x1F' on Linux, translate it back + else if (key == key_char('\x1F') && (mods & KEY_MOD_ALT) == 0) { + key = '_'; + code = WITH_CTRL(key_char('_')); + } + // treat ctrl/shift + enter always as KEY_LINEFEED for portability + else if (key == KEY_ENTER && (mods == KEY_MOD_SHIFT || mods == KEY_MOD_ALT || mods == KEY_MOD_CTRL)) { + code = KEY_LINEFEED; + } + // treat ctrl+tab always as shift+tab for portability + else if (code == WITH_CTRL(KEY_TAB)) { + code = KEY_SHIFT_TAB; + } + // treat ctrl+end/alt+>/alt-down and ctrl+home/alt+') || code == WITH_CTRL(KEY_END)) { + code = KEY_PAGEDOWN; + } + else if (code == WITH_ALT(KEY_UP) || code == WITH_ALT('<') || code == WITH_CTRL(KEY_HOME)) { + code = KEY_PAGEUP; + } + + // treat C0 codes without KEY_MOD_CTRL + if (key < ' ' && (mods&KEY_MOD_CTRL) != 0) { + code &= ~KEY_MOD_CTRL; + } + + return code; +} + + +// read a single char/key +ic_private code_t tty_read(tty_t* tty) +{ + code_t code; + if (!tty_read_timeout(tty, -1, &code)) return KEY_NONE; + return code; +} + +//------------------------------------------------------------- +// Read back an ANSI query response +//------------------------------------------------------------- + +ic_private bool tty_read_esc_response(tty_t* tty, char esc_start, bool final_st, char* buf, ssize_t buflen ) +{ + buf[0] = 0; + ssize_t len = 0; + uint8_t c = 0; + if (!tty_readc_noblock(tty, &c, 2*tty->esc_initial_timeout) || c != '\x1B') { + debug_msg("initial esc response failed: 0x%02x\n", c); + return false; + } + if (!tty_readc_noblock(tty, &c, tty->esc_timeout) || (c != esc_start)) return false; + while( len < buflen ) { + if (!tty_readc_noblock(tty, &c, tty->esc_timeout)) return false; + if (final_st) { + // OSC is terminated by BELL, or ESC \ (ST) (and STX) + if (c=='\x07' || c=='\x02') { + break; + } + else if (c=='\x1B') { + uint8_t c1; + if (!tty_readc_noblock(tty, &c1, tty->esc_timeout)) return false; + if (c1=='\\') break; + tty_cpush_char(tty,c1); + } + } + else { + if (c == '\x02') { // STX + break; + } + else if (!((c >= '0' && c <= '9') || strchr("<=>?;:",c) != NULL)) { + buf[len++] = (char)c; // for non-OSC save the terminating character + break; + } + } + buf[len++] = (char)c; + } + buf[len] = 0; + debug_msg("tty: escape query response: %s\n", buf); + return true; +} + +//------------------------------------------------------------- +// High level code pushback +//------------------------------------------------------------- + +static bool tty_code_pop( tty_t* tty, code_t* code ) { + if (tty->push_count <= 0) return false; + tty->push_count--; + *code = tty->pushbuf[tty->push_count]; + return true; +} + +ic_private void tty_code_pushback( tty_t* tty, code_t c ) { + // note: must be signal safe + if (tty->push_count >= TTY_PUSH_MAX) return; + tty->pushbuf[tty->push_count] = c; + tty->push_count++; +} + + +//------------------------------------------------------------- +// low-level character pushback (for escape sequences and windows) +//------------------------------------------------------------- + +ic_private bool tty_cpop(tty_t* tty, uint8_t* c) { + if (tty->cpush_count <= 0) { // do not modify c on failure (see `tty_decode_unicode`) + return false; + } + else { + tty->cpush_count--; + *c = tty->cpushbuf[tty->cpush_count]; + return true; + } +} + +static void tty_cpush(tty_t* tty, const char* s) { + ssize_t len = ic_strlen(s); + if (tty->push_count + len > TTY_PUSH_MAX) { + debug_msg("tty: cpush buffer full! (pushing %s)\n", s); + assert(false); + return; + } + for (ssize_t i = 0; i < len; i++) { + tty->cpushbuf[tty->cpush_count + i] = (uint8_t)( s[len - i - 1] ); + } + tty->cpush_count += len; + return; +} + +// convenience function for small sequences +static void tty_cpushf(tty_t* tty, const char* fmt, ...) { + va_list args; + va_start(args,fmt); + char buf[TTY_PUSH_MAX+1]; + vsnprintf(buf,TTY_PUSH_MAX,fmt,args); + buf[TTY_PUSH_MAX] = 0; + tty_cpush(tty,buf); + va_end(args); + return; +} + +ic_private void tty_cpush_char(tty_t* tty, uint8_t c) { + uint8_t buf[2]; + buf[0] = c; + buf[1] = 0; + tty_cpush(tty, (const char*)buf); +} + + +//------------------------------------------------------------- +// Push escape codes (used on Windows to insert keys) +//------------------------------------------------------------- + +static unsigned csi_mods(code_t mods) { + unsigned m = 1; + if (mods&KEY_MOD_SHIFT) m += 1; + if (mods&KEY_MOD_ALT) m += 2; + if (mods&KEY_MOD_CTRL) m += 4; + return m; +} + +// Push ESC [ ; ~ +static void tty_cpush_csi_vt( tty_t* tty, code_t mods, uint32_t vtcode ) { + tty_cpushf(tty,"\x1B[%u;%u~", vtcode, csi_mods(mods) ); +} + +// push ESC [ 1 ; +static void tty_cpush_csi_xterm( tty_t* tty, code_t mods, char xcode ) { + tty_cpushf(tty,"\x1B[1;%u%c", csi_mods(mods), xcode ); +} + +// push ESC [ ; u +static void tty_cpush_csi_unicode( tty_t* tty, code_t mods, uint32_t unicode ) { + if ((unicode < 0x80 && mods == 0) || + (mods == KEY_MOD_CTRL && unicode < ' ' && unicode != KEY_TAB && unicode != KEY_ENTER + && unicode != KEY_LINEFEED && unicode != KEY_BACKSP) || + (mods == KEY_MOD_SHIFT && unicode >= ' ' && unicode <= KEY_RUBOUT)) { + tty_cpush_char(tty,(uint8_t)unicode); + } + else { + tty_cpushf(tty,"\x1B[%u;%uu", unicode, csi_mods(mods) ); + } +} + +//------------------------------------------------------------- +// Init +//------------------------------------------------------------- + +static bool tty_init_raw(tty_t* tty); +static void tty_done_raw(tty_t* tty); + +static bool tty_init_utf8(tty_t* tty) { + #ifdef _WIN32 + tty->is_utf8 = true; + #else + const char* loc = setlocale(LC_ALL,""); + tty->is_utf8 = (ic_icontains(loc,"UTF-8") || ic_icontains(loc,"utf8") || ic_stricmp(loc,"C") == 0); + debug_msg("tty: utf8: %s (loc=%s)\n", tty->is_utf8 ? "true" : "false", loc); + #endif + return true; +} + +ic_private tty_t* tty_new(alloc_t* mem, int fd_in) +{ + tty_t* tty = mem_zalloc_tp(mem, tty_t); + tty->mem = mem; + tty->fd_in = (fd_in < 0 ? STDIN_FILENO : fd_in); + #if defined(__APPLE__) + tty->esc_initial_timeout = 200; // apple use ESC+ for alt- + #else + tty->esc_initial_timeout = 100; + #endif + tty->esc_timeout = 10; + if (!(isatty(tty->fd_in) && tty_init_raw(tty) && tty_init_utf8(tty))) { + tty_free(tty); + return NULL; + } + return tty; +} + +ic_private void tty_free(tty_t* tty) { + if (tty==NULL) return; + tty_end_raw(tty); + tty_done_raw(tty); + mem_free(tty->mem,tty); +} + +ic_private bool tty_is_utf8(const tty_t* tty) { + if (tty == NULL) return true; + return (tty->is_utf8); +} + +ic_private bool tty_term_resize_event(tty_t* tty) { + if (tty == NULL) return true; + if (tty->has_term_resize_event) { + if (!tty->term_resize_event) return false; + tty->term_resize_event = false; // reset. + } + return true; // always return true on systems without a resize event (more expensive but still ok) +} + +ic_private void tty_set_esc_delay(tty_t* tty, long initial_delay_ms, long followup_delay_ms) { + tty->esc_initial_timeout = (initial_delay_ms < 0 ? 0 : (initial_delay_ms > 1000 ? 1000 : initial_delay_ms)); + tty->esc_timeout = (followup_delay_ms < 0 ? 0 : (followup_delay_ms > 1000 ? 1000 : followup_delay_ms)); +} + +//------------------------------------------------------------- +// Unix +//------------------------------------------------------------- +#if !defined(_WIN32) + +static bool tty_readc_blocking(tty_t* tty, uint8_t* c) { + if (tty_cpop(tty,c)) return true; + *c = 0; + ssize_t nread = read(tty->fd_in, (char*)c, 1); + if (nread < 0 && errno == EINTR) { + // can happen on SIGWINCH signal for terminal resize + } + return (nread == 1); +} + + +// non blocking read -- with a small timeout used for reading escape sequences. +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms) +{ + // in our pushback buffer? + if (tty_cpop(tty, c)) return true; + + // blocking read? + if (timeout_ms < 0) { + return tty_readc_blocking(tty,c); + } + + // if supported, peek first if any char is available. + #if defined(FIONREAD) + { int navail = 0; + if (ioctl(0, FIONREAD, &navail) == 0) { + if (navail >= 1) { + return tty_readc_blocking(tty, c); + } + else if (timeout_ms == 0) { + return false; // return early if there is no input available (with a zero timeout) + } + } + } + #endif + + // otherwise block for at most timeout milliseconds + #if defined(FD_SET) + // we can use select to detect when input becomes available + fd_set readset; + struct timeval time; + FD_ZERO(&readset); + FD_SET(tty->fd_in, &readset); + time.tv_sec = (timeout_ms > 0 ? timeout_ms / 1000 : 0); + time.tv_usec = (timeout_ms > 0 ? 1000*(timeout_ms % 1000) : 0); + if (select(tty->fd_in + 1, &readset, NULL, NULL, &time) == 1) { + // input available + return tty_readc_blocking(tty, c); + } + #else + // no select, we cannot timeout; use usleeps :-( + // todo: this seems very rare nowadays; should be even support this? + do { + // peek ahead if possible + #if defined(FIONREAD) + int navail = 0; + if (ioctl(0, FIONREAD, &navail) == 0 && navail >= 1) { + return tty_readc_blocking(tty, c); + } + #elif defined(O_NONBLOCK) + // use a temporary non-blocking read mode + int fstatus = fcntl(tty->fd_in, F_GETFL, 0); + if (fstatus != -1) { + if (fcntl(tty->fd_in, F_SETFL, (fstatus | O_NONBLOCK)) != -1) { + char buf[2] = { 0, 0 }; + ssize_t nread = read(tty->fd_in, buf, 1); + fcntl(tty->fd_in, F_SETFL, fstatus); + if (nread >= 1) { + *c = (uint8_t)buf[0]; + return true; + } + } + } + #else + #error "define an nonblocking read for this platform" + #endif + // and sleep a bit + if (timeout_ms > 0) { + usleep(50*1000L); // sleep at most 0.05s at a time + timeout_ms -= 100; + if (timeout_ms < 0) { timeout_ms = 0; } + } + } + while (timeout_ms > 0); + #endif + return false; +} + +#if defined(TIOCSTI) +ic_private bool tty_async_stop(const tty_t* tty) { + // insert ^C in the input stream + char c = KEY_CTRL_C; + return (ioctl(tty->fd_in, TIOCSTI, &c) >= 0); +} +#else +ic_private bool tty_async_stop(const tty_t* tty) { + return false; +} +#endif + +// We install various signal handlers to restore the terminal settings +// in case of a terminating signal. This is also used to catch terminal window resizes. +// This is not strictly needed so this can be disabled on +// (older) platforms that do not support signal handling well. +#if defined(SIGWINCH) && defined(SA_RESTART) // ensure basic signal functionality is defined + +// store the tty in a global so we access it on unexpected termination +static tty_t* sig_tty; // = NULL + +// Catch all termination signals (and SIGWINCH) +typedef struct signal_handler_s { + int signum; + union { + int _avoid_warning; + struct sigaction previous; + } action; +} signal_handler_t; + +static signal_handler_t sighandlers[] = { + { SIGWINCH, {0} }, + { SIGTERM , {0} }, + { SIGINT , {0} }, + { SIGQUIT , {0} }, + { SIGHUP , {0} }, + { SIGSEGV , {0} }, + { SIGTRAP , {0} }, + { SIGBUS , {0} }, + { SIGTSTP , {0} }, + { SIGTTIN , {0} }, + { SIGTTOU , {0} }, + { 0 , {0} } +}; + +static bool sigaction_is_valid( struct sigaction* sa ) { + return (sa->sa_sigaction != NULL && sa->sa_handler != SIG_DFL && sa->sa_handler != SIG_IGN); +} + +// Generic signal handler +static void sig_handler(int signum, siginfo_t* siginfo, void* uap ) { + if (signum == SIGWINCH) { + if (sig_tty != NULL) { + sig_tty->term_resize_event = true; + } + } + else { + // the rest are termination signals; restore the terminal mode. (`tcsetattr` is signal-safe) + if (sig_tty != NULL && sig_tty->raw_enabled) { + tcsetattr(sig_tty->fd_in, TCSAFLUSH, &sig_tty->orig_ios); + sig_tty->raw_enabled = false; + } + } + // call previous handler + signal_handler_t* sh = sighandlers; + while( sh->signum != 0 && sh->signum != signum) { sh++; } + if (sh->signum == signum) { + if (sigaction_is_valid(&sh->action.previous)) { + (sh->action.previous.sa_sigaction)(signum, siginfo, uap); + } + } +} + +static void signals_install(tty_t* tty) { + sig_tty = tty; + // generic signal handler + struct sigaction handler; + memset(&handler,0,sizeof(handler)); + sigemptyset(&handler.sa_mask); + handler.sa_sigaction = &sig_handler; + handler.sa_flags = SA_RESTART; + // install for all signals + for( signal_handler_t* sh = sighandlers; sh->signum != 0; sh++ ) { + if (sigaction( sh->signum, NULL, &sh->action.previous) == 0) { // get previous + if (sh->action.previous.sa_handler != SIG_IGN) { // if not to be ignored + if (sigaction( sh->signum, &handler, &sh->action.previous ) < 0) { // install our handler + sh->action.previous.sa_sigaction = NULL; // do not restore on error + } + else if (sh->signum == SIGWINCH) { + sig_tty->has_term_resize_event = true; + }; + } + } + } +} + +static void signals_restore(void) { + // restore all signal handlers + for( signal_handler_t* sh = sighandlers; sh->signum != 0; sh++ ) { + if (sigaction_is_valid(&sh->action.previous)) { + sigaction( sh->signum, &sh->action.previous, NULL ); + }; + } + sig_tty = NULL; +} + +#else +static void signals_install(tty_t* tty) { + ic_unused(tty); + // nothing +} +static void signals_restore(void) { + // nothing +} + +#endif + +ic_private bool tty_start_raw(tty_t* tty) { + if (tty == NULL) return false; + if (tty->raw_enabled) return true; + if (tcsetattr(tty->fd_in,TCSAFLUSH,&tty->raw_ios) < 0) return false; + tty->raw_enabled = true; + return true; +} + +ic_private void tty_end_raw(tty_t* tty) { + if (tty == NULL) return; + if (!tty->raw_enabled) return; + tty->cpush_count = 0; + if (tcsetattr(tty->fd_in,TCSAFLUSH,&tty->orig_ios) < 0) return; + tty->raw_enabled = false; +} + +static bool tty_init_raw(tty_t* tty) +{ + // Set input to raw mode. See . + if (tcgetattr(tty->fd_in,&tty->orig_ios) == -1) return false; + tty->raw_ios = tty->orig_ios; + // input: no break signal, no \r to \n, no parity check, no 8-bit to 7-bit, no flow control + tty->raw_ios.c_iflag &= ~(unsigned long)(BRKINT | ICRNL | INPCK | ISTRIP | IXON); + // control: allow 8-bit + tty->raw_ios.c_cflag |= CS8; + // local: no echo, no line-by-line (canonical), no extended input processing, no signals for ^z,^c + tty->raw_ios.c_lflag &= ~(unsigned long)(ECHO | ICANON | IEXTEN | ISIG); + // 1 byte at a time, no delay + tty->raw_ios.c_cc[VTIME] = 0; + tty->raw_ios.c_cc[VMIN] = 1; + + // store in global so our signal handlers can restore the terminal mode + signals_install(tty); + + return true; +} + +static void tty_done_raw(tty_t* tty) { + ic_unused(tty); + signals_restore(); +} + + +#else + +//------------------------------------------------------------- +// Windows +// For best portability we push CSI escape sequences directly +// to the character stream (instead of returning key codes). +//------------------------------------------------------------- + +static void tty_waitc_console(tty_t* tty, long timeout_ms); + +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms) { // don't modify `c` if there is no input + // in our pushback buffer? + if (tty_cpop(tty, c)) return true; + // any events in the input queue? + tty_waitc_console(tty, timeout_ms); + return tty_cpop(tty, c); +} + +// Read from the console input events and push escape codes into the tty cbuffer. +static void tty_waitc_console(tty_t* tty, long timeout_ms) +{ + // wait for a key down event + INPUT_RECORD inp; + DWORD count; + uint32_t surrogate_hi = 0; + while (true) { + // check if there are events if in non-blocking timeout mode + if (timeout_ms >= 0) { + // first peek ahead + if (!GetNumberOfConsoleInputEvents(tty->hcon, &count)) return; + if (count == 0) { + if (timeout_ms == 0) { + // out of time + return; + } + else { + // wait for input events for at most timeout milli seconds + ULONGLONG start_ms = GetTickCount64(); + DWORD res = WaitForSingleObject(tty->hcon, (DWORD)timeout_ms); + switch (res) { + case WAIT_OBJECT_0: { + // input is available, decrease our timeout + ULONGLONG waited_ms = (GetTickCount64() - start_ms); + timeout_ms -= (long)waited_ms; + if (timeout_ms < 0) { + timeout_ms = 0; + } + break; + } + case WAIT_TIMEOUT: + case WAIT_ABANDONED: + case WAIT_FAILED: + default: + return; + } + } + } + } + + // (blocking) Read from the input + if (!ReadConsoleInputW(tty->hcon, &inp, 1, &count)) return; + if (count != 1) return; + + // resize event? + if (inp.EventType == WINDOW_BUFFER_SIZE_EVENT) { + tty->term_resize_event = true; + continue; + } + + // wait for key down events + if (inp.EventType != KEY_EVENT) continue; + + // the modifier state + DWORD modstate = inp.Event.KeyEvent.dwControlKeyState; + + // we need to handle shift up events separately + if (!inp.Event.KeyEvent.bKeyDown && inp.Event.KeyEvent.wVirtualKeyCode == VK_SHIFT) { + modstate &= (DWORD)~SHIFT_PRESSED; + } + + // ignore AltGr + DWORD altgr = LEFT_CTRL_PRESSED | RIGHT_ALT_PRESSED; + if ((modstate & altgr) == altgr) { modstate &= ~altgr; } + + + // get modifiers + code_t mods = 0; + if ((modstate & ( RIGHT_CTRL_PRESSED | LEFT_CTRL_PRESSED )) != 0) mods |= KEY_MOD_CTRL; + if ((modstate & ( RIGHT_ALT_PRESSED | LEFT_ALT_PRESSED )) != 0) mods |= KEY_MOD_ALT; + if ((modstate & SHIFT_PRESSED) != 0) mods |= KEY_MOD_SHIFT; + + // virtual keys + uint32_t chr = (uint32_t)inp.Event.KeyEvent.uChar.UnicodeChar; + WORD virt = inp.Event.KeyEvent.wVirtualKeyCode; + debug_msg("tty: console %s: %s%s%s virt 0x%04x, chr 0x%04x ('%c')\n", inp.Event.KeyEvent.bKeyDown ? "down" : "up", mods&KEY_MOD_CTRL ? "ctrl-" : "", mods&KEY_MOD_ALT ? "alt-" : "", mods&KEY_MOD_SHIFT ? "shift-" : "", virt, chr, chr); + + // only process keydown events (except for Alt-up which is used for unicode pasting...) + if (!inp.Event.KeyEvent.bKeyDown && virt != VK_MENU) { + continue; + } + + if (chr == 0) { + switch (virt) { + case VK_UP: tty_cpush_csi_xterm(tty, mods, 'A'); return; + case VK_DOWN: tty_cpush_csi_xterm(tty, mods, 'B'); return; + case VK_RIGHT: tty_cpush_csi_xterm(tty, mods, 'C'); return; + case VK_LEFT: tty_cpush_csi_xterm(tty, mods, 'D'); return; + case VK_END: tty_cpush_csi_xterm(tty, mods, 'F'); return; + case VK_HOME: tty_cpush_csi_xterm(tty, mods, 'H'); return; + case VK_DELETE: tty_cpush_csi_vt(tty,mods,3); return; + case VK_PRIOR: tty_cpush_csi_vt(tty,mods,5); return; //page up + case VK_NEXT: tty_cpush_csi_vt(tty,mods,6); return; //page down + case VK_TAB: tty_cpush_csi_unicode(tty,mods,9); return; + case VK_RETURN: tty_cpush_csi_unicode(tty,mods,13); return; + default: { + uint32_t vtcode = 0; + if (virt >= VK_F1 && virt <= VK_F5) { + vtcode = 10 + (virt - VK_F1); + } + else if (virt >= VK_F6 && virt <= VK_F10) { + vtcode = 17 + (virt - VK_F6); + } + else if (virt >= VK_F11 && virt <= VK_F12) { + vtcode = 13 + (virt - VK_F11); + } + if (vtcode > 0) { + tty_cpush_csi_vt(tty,mods,vtcode); + return; + } + } + } + // ignore other control keys (shift etc). + } + // high surrogate pair + else if (chr >= 0xD800 && chr <= 0xDBFF) { + surrogate_hi = (chr - 0xD800); + } + // low surrogate pair + else if (chr >= 0xDC00 && chr <= 0xDFFF) { + chr = ((surrogate_hi << 10) + (chr - 0xDC00) + 0x10000); + tty_cpush_csi_unicode(tty,mods,chr); + surrogate_hi = 0; + return; + } + // regular character + else { + tty_cpush_csi_unicode(tty,mods,chr); + return; + } + } +} + +ic_private bool tty_async_stop(const tty_t* tty) { + // send ^c + INPUT_RECORD events[2]; + memset(events, 0, 2*sizeof(INPUT_RECORD)); + events[0].EventType = KEY_EVENT; + events[0].Event.KeyEvent.bKeyDown = TRUE; + events[0].Event.KeyEvent.uChar.AsciiChar = KEY_CTRL_C; + events[1] = events[0]; + events[1].Event.KeyEvent.bKeyDown = FALSE; + DWORD nwritten = 0; + WriteConsoleInput(tty->hcon, events, 2, &nwritten); + return (nwritten == 2); +} + +ic_private bool tty_start_raw(tty_t* tty) { + if (tty->raw_enabled) return true; + GetConsoleMode(tty->hcon,&tty->hcon_orig_mode); + DWORD mode = ENABLE_QUICK_EDIT_MODE // cut&paste allowed + | ENABLE_WINDOW_INPUT // to catch resize events + // | ENABLE_VIRTUAL_TERMINAL_INPUT + // | ENABLE_PROCESSED_INPUT + ; + SetConsoleMode(tty->hcon, mode ); + tty->raw_enabled = true; + return true; +} + +ic_private void tty_end_raw(tty_t* tty) { + if (!tty->raw_enabled) return; + SetConsoleMode(tty->hcon, tty->hcon_orig_mode ); + tty->raw_enabled = false; +} + +static bool tty_init_raw(tty_t* tty) { + tty->hcon = GetStdHandle( STD_INPUT_HANDLE ); + tty->has_term_resize_event = true; + return true; +} + +static void tty_done_raw(tty_t* tty) { + ic_unused(tty); +} + +#endif + + diff --git a/extern/isocline/src/tty.h b/extern/isocline/src/tty.h new file mode 100644 index 000000000..a0062bf30 --- /dev/null +++ b/extern/isocline/src/tty.h @@ -0,0 +1,160 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_TTY_H +#define IC_TTY_H + +#include "common.h" + +//------------------------------------------------------------- +// TTY/Keyboard input +//------------------------------------------------------------- + +// Key code +typedef uint32_t code_t; + +// TTY interface +struct tty_s; +typedef struct tty_s tty_t; + + +ic_private tty_t* tty_new(alloc_t* mem, int fd_in); +ic_private void tty_free(tty_t* tty); + +ic_private bool tty_is_utf8(const tty_t* tty); +ic_private bool tty_start_raw(tty_t* tty); +ic_private void tty_end_raw(tty_t* tty); +ic_private code_t tty_read(tty_t* tty); +ic_private bool tty_read_timeout(tty_t* tty, long timeout_ms, code_t* c ); + +ic_private void tty_code_pushback( tty_t* tty, code_t c ); +ic_private bool code_is_ascii_char(code_t c, char* chr ); +ic_private bool code_is_unicode(code_t c, unicode_t* uchr); +ic_private bool code_is_virt_key(code_t c ); + +ic_private bool tty_term_resize_event(tty_t* tty); // did the terminal resize? +ic_private bool tty_async_stop(const tty_t* tty); // unblock the read asynchronously +ic_private void tty_set_esc_delay(tty_t* tty, long initial_delay_ms, long followup_delay_ms); + +// shared between tty.c and tty_esc.c: low level character push +ic_private void tty_cpush_char(tty_t* tty, uint8_t c); +ic_private bool tty_cpop(tty_t* tty, uint8_t* c); +ic_private bool tty_readc_noblock(tty_t* tty, uint8_t* c, long timeout_ms); +ic_private code_t tty_read_esc(tty_t* tty, long esc_initial_timeout, long esc_timeout); // in tty_esc.c + +// used by term.c to read back ANSI escape responses +ic_private bool tty_read_esc_response(tty_t* tty, char esc_start, bool final_st, char* buf, ssize_t buflen ); + + +//------------------------------------------------------------- +// Key codes: a code_t is 32 bits. +// we use the bottom 24 (nah, 21) bits for unicode (up to x0010FFFF) +// The codes after x01000000 are for virtual keys +// and events use x02000000. +// The top 4 bits are used for modifiers. +//------------------------------------------------------------- + +static inline code_t key_char( char c ) { + // careful about signed character conversion (negative char ~> 0x80 - 0xFF) + return ((uint8_t)c); +} + +static inline code_t key_unicode( unicode_t u ) { + return u; +} + + +#define KEY_MOD_SHIFT (0x10000000U) +#define KEY_MOD_ALT (0x20000000U) +#define KEY_MOD_CTRL (0x40000000U) + +#define KEY_NO_MODS(k) (k & 0x0FFFFFFFU) +#define KEY_MODS(k) (k & 0xF0000000U) + +#define WITH_SHIFT(x) (x | KEY_MOD_SHIFT) +#define WITH_ALT(x) (x | KEY_MOD_ALT) +#define WITH_CTRL(x) (x | KEY_MOD_CTRL) + +#define KEY_NONE (0) +#define KEY_CTRL_A (1) +#define KEY_CTRL_B (2) +#define KEY_CTRL_C (3) +#define KEY_CTRL_D (4) +#define KEY_CTRL_E (5) +#define KEY_CTRL_F (6) +#define KEY_BELL (7) +#define KEY_BACKSP (8) +#define KEY_TAB (9) +#define KEY_LINEFEED (10) // ctrl/shift + enter is considered KEY_LINEFEED +#define KEY_CTRL_K (11) +#define KEY_CTRL_L (12) +#define KEY_ENTER (13) +#define KEY_CTRL_N (14) +#define KEY_CTRL_O (15) +#define KEY_CTRL_P (16) +#define KEY_CTRL_Q (17) +#define KEY_CTRL_R (18) +#define KEY_CTRL_S (19) +#define KEY_CTRL_T (20) +#define KEY_CTRL_U (21) +#define KEY_CTRL_V (22) +#define KEY_CTRL_W (23) +#define KEY_CTRL_X (24) +#define KEY_CTRL_Y (25) +#define KEY_CTRL_Z (26) +#define KEY_ESC (27) +#define KEY_SPACE (32) +#define KEY_RUBOUT (127) // always translated to KEY_BACKSP +#define KEY_UNICODE_MAX (0x0010FFFFU) + + +#define KEY_VIRT (0x01000000U) +#define KEY_UP (KEY_VIRT+0) +#define KEY_DOWN (KEY_VIRT+1) +#define KEY_LEFT (KEY_VIRT+2) +#define KEY_RIGHT (KEY_VIRT+3) +#define KEY_HOME (KEY_VIRT+4) +#define KEY_END (KEY_VIRT+5) +#define KEY_DEL (KEY_VIRT+6) +#define KEY_PAGEUP (KEY_VIRT+7) +#define KEY_PAGEDOWN (KEY_VIRT+8) +#define KEY_INS (KEY_VIRT+9) + +#define KEY_F1 (KEY_VIRT+11) +#define KEY_F2 (KEY_VIRT+12) +#define KEY_F3 (KEY_VIRT+13) +#define KEY_F4 (KEY_VIRT+14) +#define KEY_F5 (KEY_VIRT+15) +#define KEY_F6 (KEY_VIRT+16) +#define KEY_F7 (KEY_VIRT+17) +#define KEY_F8 (KEY_VIRT+18) +#define KEY_F9 (KEY_VIRT+19) +#define KEY_F10 (KEY_VIRT+20) +#define KEY_F11 (KEY_VIRT+21) +#define KEY_F12 (KEY_VIRT+22) +#define KEY_F(n) (KEY_F1 + (n) - 1) + +#define KEY_EVENT_BASE (0x02000000U) +#define KEY_EVENT_RESIZE (KEY_EVENT_BASE+1) +#define KEY_EVENT_AUTOTAB (KEY_EVENT_BASE+2) +#define KEY_EVENT_STOP (KEY_EVENT_BASE+3) + +// Convenience +#define KEY_CTRL_UP (WITH_CTRL(KEY_UP)) +#define KEY_CTRL_DOWN (WITH_CTRL(KEY_DOWN)) +#define KEY_CTRL_LEFT (WITH_CTRL(KEY_LEFT)) +#define KEY_CTRL_RIGHT (WITH_CTRL(KEY_RIGHT)) +#define KEY_CTRL_HOME (WITH_CTRL(KEY_HOME)) +#define KEY_CTRL_END (WITH_CTRL(KEY_END)) +#define KEY_CTRL_DEL (WITH_CTRL(KEY_DEL)) +#define KEY_CTRL_PAGEUP (WITH_CTRL(KEY_PAGEUP)) +#define KEY_CTRL_PAGEDOWN (WITH_CTRL(KEY_PAGEDOWN))) +#define KEY_CTRL_INS (WITH_CTRL(KEY_INS)) + +#define KEY_SHIFT_TAB (WITH_SHIFT(KEY_TAB)) + +#endif // IC_TTY_H diff --git a/extern/isocline/src/tty_esc.c b/extern/isocline/src/tty_esc.c new file mode 100644 index 000000000..0ac8761db --- /dev/null +++ b/extern/isocline/src/tty_esc.c @@ -0,0 +1,401 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include "tty.h" + +/*------------------------------------------------------------- +Decoding escape sequences to key codes. +This is a bit tricky there are many variants to encode keys as escape sequences, see for example: +- . +- +- +- +- + +Generally, for our purposes we accept a subset of escape sequences as: + + escseq ::= ESC + | ESC char + | ESC start special? (number (';' modifiers)?)? final + +where: + char ::= [\x00-\xFF] # any character + special ::= [:<=>?] + number ::= [0-9+] + modifiers ::= [1-9] + intermediate ::= [\x20-\x2F] # !"#$%&'()*+,-./ + final ::= [\x40-\x7F] # @A–Z[\]^_`a–z{|}~ + ESC ::= '\x1B' + CSI ::= ESC '[' + SS3 ::= ESC 'O' + +In ECMA48 `special? (number (';' modifiers)?)?` is the more liberal `[\x30-\x3F]*` +but that seems never used for key codes. If the number (vtcode or unicode) or the +modifiers are not given, we assume these are '1'. +We then accept the following key sequences: + + key ::= ESC # lone ESC + | ESC char # Alt+char + | ESC '[' special? vtcode ';' modifiers '~' # vt100 codes + | ESC '[' special? '1' ';' modifiers [A-Z] # xterm codes + | ESC 'O' special? '1' ';' modifiers [A-Za-z] # SS3 codes + | ESC '[' special? unicode ';' modifiers 'u' # direct unicode code + +Moreover, we translate the following special cases that do not fit into the above grammar. +First we translate away special starter sequences: +--------------------------------------------------------------------- + ESC '[' '[' .. ~> ESC '[' .. # Linux sometimes uses extra '[' for CSI + ESC '[' 'O' .. ~> ESC 'O' .. # Linux sometimes uses extra '[' for SS3 + ESC 'o' .. ~> ESC 'O' .. # Eterm: ctrl + SS3 + ESC '?' .. ~> ESC 'O' .. # vt52 treated as SS3 + +And then translate the following special cases into a standard form: +--------------------------------------------------------------------- + ESC '[' .. '@' ~> ESC '[' '3' '~' # Del on Mach + ESC '[' .. '9' ~> ESC '[' '2' '~' # Ins on Mach + ESC .. [^@$] ~> ESC .. '~' # ETerm,xrvt,urxt: ^ = ctrl, $ = shift, @ = alt + ESC '[' [a-d] ~> ESC '[' '1' ';' '2' [A-D] # Eterm shift+ + ESC 'O' [1-9] final ~> ESC 'O' '1' ';' [1-9] final # modifiers as parameter 1 (like on Haiku) + ESC '[' [1-9] [^~u] ~> ESC 'O' '1' ';' [1-9] final # modifiers as parameter 1 + +The modifier keys are encoded as "(modifiers-1) & mask" where the +shift mask is 0x01, alt 0x02 and ctrl 0x04. Therefore: +------------------------------------------------------------ + 1: - 5: ctrl 9: alt (for minicom) + 2: shift 6: shift+ctrl + 3: alt 7: alt+ctrl + 4: shift+alt 8: shift+alt+ctrl + +The different encodings fox vt100, xterm, and SS3 are: + +vt100: ESC [ vtcode ';' modifiers '~' +-------------------------------------- + 1: Home 10-15: F1-F5 + 2: Ins 16 : F5 + 3: Del 17-21: F6-F10 + 4: End 23-26: F11-F14 + 5: PageUp 28 : F15 + 6: PageDn 29 : F16 + 7: Home 31-34: F17-F20 + 8: End + +xterm: ESC [ 1 ';' modifiers [A-Z] +----------------------------------- + A: Up N: F2 + B: Down O: F3 + C: Right P: F4 + D: Left Q: F5 + E: '5' R: F6 + F: End S: F7 + G: T: F8 + H: Home U: PageDn + I: PageUp V: PageUp + J: W: F11 + K: X: F12 + L: Ins Y: End + M: F1 Z: shift+Tab + +SS3: ESC 'O' 1 ';' modifiers [A-Za-z] +--------------------------------------- + (normal) (numpad) + A: Up N: a: Up n: + B: Down O: b: Down o: + C: Right P: F1 c: Right p: Ins + D: Left Q: F2 d: Left q: End + E: '5' R: F3 e: r: Down + F: End S: F4 f: s: PageDn + G: T: F5 g: t: Left + H: Home U: F6 h: u: '5' + I: Tab V: F7 i: v: Right + J: W: F8 j: '*' w: Home + K: X: F9 k: '+' x: Up + L: Y: F10 l: ',' y: PageUp + M: \x0A '\n' Z: shift+Tab m: '-' z: + +-------------------------------------------------------------*/ + +//------------------------------------------------------------- +// Decode escape sequences +//------------------------------------------------------------- + +static code_t esc_decode_vt(uint32_t vt_code ) { + switch(vt_code) { + case 1: return KEY_HOME; + case 2: return KEY_INS; + case 3: return KEY_DEL; + case 4: return KEY_END; + case 5: return KEY_PAGEUP; + case 6: return KEY_PAGEDOWN; + case 7: return KEY_HOME; + case 8: return KEY_END; + default: + if (vt_code >= 10 && vt_code <= 15) return KEY_F(1 + (vt_code - 10)); + if (vt_code == 16) return KEY_F5; // minicom + if (vt_code >= 17 && vt_code <= 21) return KEY_F(6 + (vt_code - 17)); + if (vt_code >= 23 && vt_code <= 26) return KEY_F(11 + (vt_code - 23)); + if (vt_code >= 28 && vt_code <= 29) return KEY_F(15 + (vt_code - 28)); + if (vt_code >= 31 && vt_code <= 34) return KEY_F(17 + (vt_code - 31)); + } + return KEY_NONE; +} + +static code_t esc_decode_xterm( uint8_t xcode ) { + // ESC [ + switch(xcode) { + case 'A': return KEY_UP; + case 'B': return KEY_DOWN; + case 'C': return KEY_RIGHT; + case 'D': return KEY_LEFT; + case 'E': return '5'; // numpad 5 + case 'F': return KEY_END; + case 'H': return KEY_HOME; + case 'Z': return KEY_TAB | KEY_MOD_SHIFT; + // Freebsd: + case 'I': return KEY_PAGEUP; + case 'L': return KEY_INS; + case 'M': return KEY_F1; + case 'N': return KEY_F2; + case 'O': return KEY_F3; + case 'P': return KEY_F4; // note: differs from + case 'Q': return KEY_F5; + case 'R': return KEY_F6; + case 'S': return KEY_F7; + case 'T': return KEY_F8; + case 'U': return KEY_PAGEDOWN; // Mach + case 'V': return KEY_PAGEUP; // Mach + case 'W': return KEY_F11; + case 'X': return KEY_F12; + case 'Y': return KEY_END; // Mach + } + return KEY_NONE; +} + +static code_t esc_decode_ss3( uint8_t ss3_code ) { + // ESC O + switch(ss3_code) { + case 'A': return KEY_UP; + case 'B': return KEY_DOWN; + case 'C': return KEY_RIGHT; + case 'D': return KEY_LEFT; + case 'E': return '5'; // numpad 5 + case 'F': return KEY_END; + case 'H': return KEY_HOME; + case 'I': return KEY_TAB; + case 'Z': return KEY_TAB | KEY_MOD_SHIFT; + case 'M': return KEY_LINEFEED; + case 'P': return KEY_F1; + case 'Q': return KEY_F2; + case 'R': return KEY_F3; + case 'S': return KEY_F4; + // on Mach + case 'T': return KEY_F5; + case 'U': return KEY_F6; + case 'V': return KEY_F7; + case 'W': return KEY_F8; + case 'X': return KEY_F9; // '=' on vt220 + case 'Y': return KEY_F10; + // numpad + case 'a': return KEY_UP; + case 'b': return KEY_DOWN; + case 'c': return KEY_RIGHT; + case 'd': return KEY_LEFT; + case 'j': return '*'; + case 'k': return '+'; + case 'l': return ','; + case 'm': return '-'; + case 'n': return KEY_DEL; // '.' + case 'o': return '/'; + case 'p': return KEY_INS; + case 'q': return KEY_END; + case 'r': return KEY_DOWN; + case 's': return KEY_PAGEDOWN; + case 't': return KEY_LEFT; + case 'u': return '5'; + case 'v': return KEY_RIGHT; + case 'w': return KEY_HOME; + case 'x': return KEY_UP; + case 'y': return KEY_PAGEUP; + } + return KEY_NONE; +} + +static void tty_read_csi_num(tty_t* tty, uint8_t* ppeek, uint32_t* num, long esc_timeout) { + *num = 1; // default + ssize_t count = 0; + uint32_t i = 0; + while (*ppeek >= '0' && *ppeek <= '9' && count < 16) { + uint8_t digit = *ppeek - '0'; + if (!tty_readc_noblock(tty,ppeek,esc_timeout)) break; // peek is not modified in this case + count++; + i = 10*i + digit; + } + if (count > 0) *num = i; +} + +static code_t tty_read_csi(tty_t* tty, uint8_t c1, uint8_t peek, code_t mods0, long esc_timeout) { + // CSI starts with 0x9b (c1=='[') | ESC [ (c1=='[') | ESC [Oo?] (c1 == 'O') /* = SS3 */ + + // check for extra starter '[' (Linux sends ESC [ [ 15 ~ for F5 for example) + if (c1 == '[' && strchr("[Oo", (char)peek) != NULL) { + uint8_t cx = peek; + if (tty_readc_noblock(tty,&peek,esc_timeout)) { + c1 = cx; + } + } + + // "special" characters ('?' is used for private sequences) + uint8_t special = 0; + if (strchr(":<=>?",(char)peek) != NULL) { + special = peek; + if (!tty_readc_noblock(tty,&peek,esc_timeout)) { + tty_cpush_char(tty,special); // recover + return (key_unicode(c1) | KEY_MOD_ALT); // Alt+ + } + } + + // up to 2 parameters that default to 1 + uint32_t num1 = 1; + uint32_t num2 = 1; + tty_read_csi_num(tty,&peek,&num1,esc_timeout); + if (peek == ';') { + if (!tty_readc_noblock(tty,&peek,esc_timeout)) return KEY_NONE; + tty_read_csi_num(tty,&peek,&num2,esc_timeout); + } + + // the final character (we do not allow 'intermediate characters') + uint8_t final = peek; + code_t modifiers = mods0; + + debug_msg("tty: escape sequence: ESC %c %c %d;%d %c\n", c1, (special == 0 ? '_' : special), num1, num2, final); + + // Adjust special cases into standard ones. + if ((final == '@' || final == '9') && c1 == '[' && num1 == 1) { + // ESC [ @, ESC [ 9 : on Mach + if (final == '@') num1 = 3; // DEL + else if (final == '9') num1 = 2; // INS + final = '~'; + } + else if (final == '^' || final == '$' || final == '@') { + // Eterm/rxvt/urxt + if (final=='^') modifiers |= KEY_MOD_CTRL; + if (final=='$') modifiers |= KEY_MOD_SHIFT; + if (final=='@') modifiers |= KEY_MOD_SHIFT | KEY_MOD_CTRL; + final = '~'; + } + else if (c1 == '[' && final >= 'a' && final <= 'd') { // note: do not catch ESC [ .. u (for unicode) + // ESC [ [a-d] : on Eterm for shift+ cursor + modifiers |= KEY_MOD_SHIFT; + final = 'A' + (final - 'a'); + } + + if (((c1 == 'O') || (c1=='[' && final != '~' && final != 'u')) && + (num2 == 1 && num1 > 1 && num1 <= 8)) + { + // on haiku the modifier can be parameter 1, make it parameter 2 instead + num2 = num1; + num1 = 1; + } + + // parameter 2 determines the modifiers + if (num2 > 1 && num2 <= 9) { + if (num2 == 9) num2 = 3; // iTerm2 in xterm mode + num2--; + if (num2 & 0x1) modifiers |= KEY_MOD_SHIFT; + if (num2 & 0x2) modifiers |= KEY_MOD_ALT; + if (num2 & 0x4) modifiers |= KEY_MOD_CTRL; + } + + // and translate + code_t code = KEY_NONE; + if (final == '~') { + // vt codes + code = esc_decode_vt(num1); + } + else if (c1 == '[' && final == 'u') { + // unicode + code = key_unicode(num1); + } + else if (c1 == 'O' && ((final >= 'A' && final <= 'Z') || (final >= 'a' && final <= 'z'))) { + // ss3 + code = esc_decode_ss3(final); + } + else if (num1 == 1 && final >= 'A' && final <= 'Z') { + // xterm + code = esc_decode_xterm(final); + } + else if (c1 == '[' && final == 'R') { + // cursor position + code = KEY_NONE; + } + + if (code == KEY_NONE && final != 'R') { + debug_msg("tty: ignore escape sequence: ESC %c %zu;%zu %c\n", c1, num1, num2, final); + } + return (code != KEY_NONE ? (code | modifiers) : KEY_NONE); +} + +static code_t tty_read_osc( tty_t* tty, uint8_t* ppeek, long esc_timeout ) { + debug_msg("discard OSC response..\n"); + // keep reading until termination: OSC is terminated by BELL, or ESC \ (ST) (and STX) + while (true) { + uint8_t c = *ppeek; + if (c <= '\x07') { // BELL and anything below (STX, ^C, ^D) + if (c != '\x07') { tty_cpush_char( tty, c ); } + break; + } + else if (c=='\x1B') { + uint8_t c1; + if (!tty_readc_noblock(tty, &c1, esc_timeout)) break; + if (c1=='\\') break; + tty_cpush_char(tty,c1); + } + if (!tty_readc_noblock(tty, ppeek, esc_timeout)) break; + } + return KEY_NONE; +} + +ic_private code_t tty_read_esc(tty_t* tty, long esc_initial_timeout, long esc_timeout) { + code_t mods = 0; + uint8_t peek = 0; + + // lone ESC? + if (!tty_readc_noblock(tty, &peek, esc_initial_timeout)) return KEY_ESC; + + // treat ESC ESC as Alt modifier (macOS sends ESC ESC [ [A-D] for alt-) + if (peek == KEY_ESC) { + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + mods |= KEY_MOD_ALT; + } + + // CSI ? + if (peek == '[') { + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + return tty_read_csi(tty, '[', peek, mods, esc_timeout); // ESC [ ... + } + + // SS3? + if (peek == 'O' || peek == 'o' || peek == '?' /*vt52*/) { + uint8_t c1 = peek; + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + if (c1 == 'o') { + // ETerm uses this for ctrl+ + mods |= KEY_MOD_CTRL; + } + // treat all as standard SS3 'O' + return tty_read_csi(tty,'O',peek,mods, esc_timeout); // ESC [Oo?] ... + } + + // OSC: we may get a delayed query response; ensure it is ignored + if (peek == ']') { + if (!tty_readc_noblock(tty, &peek, esc_timeout)) goto alt; + return tty_read_osc(tty, &peek, esc_timeout); // ESC ] ... + } + +alt: + // Alt+ + return (key_unicode(peek) | KEY_MOD_ALT); // ESC +} diff --git a/extern/isocline/src/undo.c b/extern/isocline/src/undo.c new file mode 100644 index 000000000..eefc318d7 --- /dev/null +++ b/extern/isocline/src/undo.c @@ -0,0 +1,67 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#include +#include + +#include "../include/isocline.h" +#include "common.h" +#include "env.h" +#include "stringbuf.h" +#include "completions.h" +#include "undo.h" + + + +//------------------------------------------------------------- +// edit state +//------------------------------------------------------------- +struct editstate_s { + struct editstate_s* next; + const char* input; // input + ssize_t pos; // cursor position +}; + +ic_private void editstate_init( editstate_t** es ) { + *es = NULL; +} + +ic_private void editstate_done( alloc_t* mem, editstate_t** es ) { + while (*es != NULL) { + editstate_t* next = (*es)->next; + mem_free(mem, (*es)->input); + mem_free(mem, *es ); + *es = next; + } + *es = NULL; +} + +ic_private void editstate_capture( alloc_t* mem, editstate_t** es, const char* input, ssize_t pos) { + if (input==NULL) input = ""; + // alloc + editstate_t* entry = mem_zalloc_tp(mem, editstate_t); + if (entry == NULL) return; + // initialize + entry->input = mem_strdup( mem, input); + entry->pos = pos; + if (entry->input == NULL) { mem_free(mem, entry); return; } + // and push + entry->next = *es; + *es = entry; +} + +// caller should free *input +ic_private bool editstate_restore( alloc_t* mem, editstate_t** es, const char** input, ssize_t* pos ) { + if (*es == NULL) return false; + // pop + editstate_t* entry = *es; + *es = entry->next; + *input = entry->input; + *pos = entry->pos; + mem_free(mem, entry); + return true; +} + diff --git a/extern/isocline/src/undo.h b/extern/isocline/src/undo.h new file mode 100644 index 000000000..576cf9773 --- /dev/null +++ b/extern/isocline/src/undo.h @@ -0,0 +1,24 @@ +/* ---------------------------------------------------------------------------- + Copyright (c) 2021, Daan Leijen + This is free software; you can redistribute it and/or modify it + under the terms of the MIT License. A copy of the license can be + found in the "LICENSE" file at the root of this distribution. +-----------------------------------------------------------------------------*/ +#pragma once +#ifndef IC_UNDO_H +#define IC_UNDO_H + +#include "common.h" + +//------------------------------------------------------------- +// Edit state +//------------------------------------------------------------- +struct editstate_s; +typedef struct editstate_s editstate_t; + +ic_private void editstate_init( editstate_t** es ); +ic_private void editstate_done( alloc_t* mem, editstate_t** es ); +ic_private void editstate_capture( alloc_t* mem, editstate_t** es, const char* input, ssize_t pos); +ic_private bool editstate_restore( alloc_t* mem, editstate_t** es, const char** input, ssize_t* pos ); // caller needs to free input + +#endif // IC_UNDO_H diff --git a/extern/isocline/src/wcwidth.c b/extern/isocline/src/wcwidth.c new file mode 100644 index 000000000..85187d415 --- /dev/null +++ b/extern/isocline/src/wcwidth.c @@ -0,0 +1,292 @@ +// include in "stringbuf.c" +/* + * This is an implementation of wcwidth() and wcswidth() (defined in + * IEEE Std 1002.1-2001) for Unicode. + * + * http://www.opengroup.org/onlinepubs/007904975/functions/wcwidth.html + * http://www.opengroup.org/onlinepubs/007904975/functions/wcswidth.html + * + * In fixed-width output devices, Latin characters all occupy a single + * "cell" position of equal width, whereas ideographic CJK characters + * occupy two such cells. Interoperability between terminal-line + * applications and (teletype-style) character terminals using the + * UTF-8 encoding requires agreement on which character should advance + * the cursor by how many cell positions. No established formal + * standards exist at present on which Unicode character shall occupy + * how many cell positions on character terminals. These routines are + * a first attempt of defining such behavior based on simple rules + * applied to data provided by the Unicode Consortium. + * + * For some graphical characters, the Unicode standard explicitly + * defines a character-cell width via the definition of the East Asian + * FullWidth (F), Wide (W), Half-width (H), and Narrow (Na) classes. + * In all these cases, there is no ambiguity about which width a + * terminal shall use. For characters in the East Asian Ambiguous (A) + * class, the width choice depends purely on a preference of backward + * compatibility with either historic CJK or Western practice. + * Choosing single-width for these characters is easy to justify as + * the appropriate long-term solution, as the CJK practice of + * displaying these characters as double-width comes from historic + * implementation simplicity (8-bit encoded characters were displayed + * single-width and 16-bit ones double-width, even for Greek, + * Cyrillic, etc.) and not any typographic considerations. + * + * Much less clear is the choice of width for the Not East Asian + * (Neutral) class. Existing practice does not dictate a width for any + * of these characters. It would nevertheless make sense + * typographically to allocate two character cells to characters such + * as for instance EM SPACE or VOLUME INTEGRAL, which cannot be + * represented adequately with a single-width glyph. The following + * routines at present merely assign a single-cell width to all + * neutral characters, in the interest of simplicity. This is not + * entirely satisfactory and should be reconsidered before + * establishing a formal standard in this area. At the moment, the + * decision which Not East Asian (Neutral) characters should be + * represented by double-width glyphs cannot yet be answered by + * applying a simple rule from the Unicode database content. Setting + * up a proper standard for the behavior of UTF-8 character terminals + * will require a careful analysis not only of each Unicode character, + * but also of each presentation form, something the author of these + * routines has avoided to do so far. + * + * http://www.unicode.org/unicode/reports/tr11/ + * + * Markus Kuhn -- 2007-05-26 (Unicode 5.0) + * + * Permission to use, copy, modify, and distribute this software + * for any purpose and without fee is hereby granted. The author + * disclaims all warranties with regard to this software. + * + * Latest version: http://www.cl.cam.ac.uk/~mgk25/ucs/wcwidth.c + */ + +#include +#include + +struct interval { + int32_t first; + int32_t last; +}; + +/* auxiliary function for binary search in interval table */ +static int bisearch(int32_t ucs, const struct interval *table, int max) { + int min = 0; + int mid; + + if (ucs < table[0].first || ucs > table[max].last) + return 0; + while (max >= min) { + mid = (min + max) / 2; + if (ucs > table[mid].last) + min = mid + 1; + else if (ucs < table[mid].first) + max = mid - 1; + else + return 1; + } + + return 0; +} + + +/* The following two functions define the column width of an ISO 10646 + * character as follows: + * + * - The null character (U+0000) has a column width of 0. + * + * - Other C0/C1 control characters and DEL will lead to a return + * value of -1. + * + * - Non-spacing and enclosing combining characters (general + * category code Mn or Me in the Unicode database) have a + * column width of 0. + * + * - SOFT HYPHEN (U+00AD) has a column width of 1. + * + * - Other format characters (general category code Cf in the Unicode + * database) and ZERO WIDTH SPACE (U+200B) have a column width of 0. + * + * - Hangul Jamo medial vowels and final consonants (U+1160-U+11FF) + * have a column width of 0. + * + * - Spacing characters in the East Asian Wide (W) or East Asian + * Full-width (F) category as defined in Unicode Technical + * Report #11 have a column width of 2. + * + * - All remaining characters (including all printable + * ISO 8859-1 and WGL4 characters, Unicode control characters, + * etc.) have a column width of 1. + * + * This implementation assumes that wchar_t characters are encoded + * in ISO 10646. + */ + +static int mk_is_wide_char(int32_t ucs) { + static const struct interval wide[] = { + {0x1100, 0x115f}, {0x231a, 0x231b}, {0x2329, 0x232a}, + {0x23e9, 0x23ec}, {0x23f0, 0x23f0}, {0x23f3, 0x23f3}, + {0x25fd, 0x25fe}, {0x2614, 0x2615}, {0x2648, 0x2653}, + {0x267f, 0x267f}, {0x2693, 0x2693}, {0x26a1, 0x26a1}, + {0x26aa, 0x26ab}, {0x26bd, 0x26be}, {0x26c4, 0x26c5}, + {0x26ce, 0x26ce}, {0x26d4, 0x26d4}, {0x26ea, 0x26ea}, + {0x26f2, 0x26f3}, {0x26f5, 0x26f5}, {0x26fa, 0x26fa}, + {0x26fd, 0x26fd}, {0x2705, 0x2705}, {0x270a, 0x270b}, + {0x2728, 0x2728}, {0x274c, 0x274c}, {0x274e, 0x274e}, + {0x2753, 0x2755}, {0x2757, 0x2757}, {0x2795, 0x2797}, + {0x27b0, 0x27b0}, {0x27bf, 0x27bf}, {0x2b1b, 0x2b1c}, + {0x2b50, 0x2b50}, {0x2b55, 0x2b55}, {0x2e80, 0x2fdf}, + {0x2ff0, 0x303e}, {0x3040, 0x3247}, {0x3250, 0x4dbf}, + {0x4e00, 0xa4cf}, {0xa960, 0xa97f}, {0xac00, 0xd7a3}, + {0xf900, 0xfaff}, {0xfe10, 0xfe19}, {0xfe30, 0xfe6f}, + {0xff01, 0xff60}, {0xffe0, 0xffe6}, {0x16fe0, 0x16fe1}, + {0x17000, 0x18aff}, {0x1b000, 0x1b12f}, {0x1b170, 0x1b2ff}, + {0x1f004, 0x1f004}, {0x1f0cf, 0x1f0cf}, {0x1f18e, 0x1f18e}, + {0x1f191, 0x1f19a}, {0x1f200, 0x1f202}, {0x1f210, 0x1f23b}, + {0x1f240, 0x1f248}, {0x1f250, 0x1f251}, {0x1f260, 0x1f265}, + {0x1f300, 0x1f320}, {0x1f32d, 0x1f335}, {0x1f337, 0x1f37c}, + {0x1f37e, 0x1f393}, {0x1f3a0, 0x1f3ca}, {0x1f3cf, 0x1f3d3}, + {0x1f3e0, 0x1f3f0}, {0x1f3f4, 0x1f3f4}, {0x1f3f8, 0x1f43e}, + {0x1f440, 0x1f440}, {0x1f442, 0x1f4fc}, {0x1f4ff, 0x1f53d}, + {0x1f54b, 0x1f54e}, {0x1f550, 0x1f567}, {0x1f57a, 0x1f57a}, + {0x1f595, 0x1f596}, {0x1f5a4, 0x1f5a4}, {0x1f5fb, 0x1f64f}, + {0x1f680, 0x1f6c5}, {0x1f6cc, 0x1f6cc}, {0x1f6d0, 0x1f6d2}, + {0x1f6eb, 0x1f6ec}, {0x1f6f4, 0x1f6f8}, {0x1f910, 0x1f93e}, + {0x1f940, 0x1f94c}, {0x1f950, 0x1f96b}, {0x1f980, 0x1f997}, + {0x1f9c0, 0x1f9c0}, {0x1f9d0, 0x1f9e6}, {0x20000, 0x2fffd}, + {0x30000, 0x3fffd}, + }; + + if ( bisearch(ucs, wide, sizeof(wide) / sizeof(struct interval) - 1) ) { + return 1; + } + + return 0; +} + +static int mk_wcwidth(int32_t ucs) { + /* sorted list of non-overlapping intervals of non-spacing characters */ + /* generated by "uniset +cat=Me +cat=Mn +cat=Cf -00AD +1160-11FF +200B c" */ + static const struct interval combining[] = { + {0x00ad, 0x00ad}, {0x0300, 0x036f}, {0x0483, 0x0489}, + {0x0591, 0x05bd}, {0x05bf, 0x05bf}, {0x05c1, 0x05c2}, + {0x05c4, 0x05c5}, {0x05c7, 0x05c7}, {0x0610, 0x061a}, + {0x061c, 0x061c}, {0x064b, 0x065f}, {0x0670, 0x0670}, + {0x06d6, 0x06dc}, {0x06df, 0x06e4}, {0x06e7, 0x06e8}, + {0x06ea, 0x06ed}, {0x0711, 0x0711}, {0x0730, 0x074a}, + {0x07a6, 0x07b0}, {0x07eb, 0x07f3}, {0x0816, 0x0819}, + {0x081b, 0x0823}, {0x0825, 0x0827}, {0x0829, 0x082d}, + {0x0859, 0x085b}, {0x08d4, 0x08e1}, {0x08e3, 0x0902}, + {0x093a, 0x093a}, {0x093c, 0x093c}, {0x0941, 0x0948}, + {0x094d, 0x094d}, {0x0951, 0x0957}, {0x0962, 0x0963}, + {0x0981, 0x0981}, {0x09bc, 0x09bc}, {0x09c1, 0x09c4}, + {0x09cd, 0x09cd}, {0x09e2, 0x09e3}, {0x0a01, 0x0a02}, + {0x0a3c, 0x0a3c}, {0x0a41, 0x0a42}, {0x0a47, 0x0a48}, + {0x0a4b, 0x0a4d}, {0x0a51, 0x0a51}, {0x0a70, 0x0a71}, + {0x0a75, 0x0a75}, {0x0a81, 0x0a82}, {0x0abc, 0x0abc}, + {0x0ac1, 0x0ac5}, {0x0ac7, 0x0ac8}, {0x0acd, 0x0acd}, + {0x0ae2, 0x0ae3}, {0x0afa, 0x0aff}, {0x0b01, 0x0b01}, + {0x0b3c, 0x0b3c}, {0x0b3f, 0x0b3f}, {0x0b41, 0x0b44}, + {0x0b4d, 0x0b4d}, {0x0b56, 0x0b56}, {0x0b62, 0x0b63}, + {0x0b82, 0x0b82}, {0x0bc0, 0x0bc0}, {0x0bcd, 0x0bcd}, + {0x0c00, 0x0c00}, {0x0c3e, 0x0c40}, {0x0c46, 0x0c48}, + {0x0c4a, 0x0c4d}, {0x0c55, 0x0c56}, {0x0c62, 0x0c63}, + {0x0c81, 0x0c81}, {0x0cbc, 0x0cbc}, {0x0cbf, 0x0cbf}, + {0x0cc6, 0x0cc6}, {0x0ccc, 0x0ccd}, {0x0ce2, 0x0ce3}, + {0x0d00, 0x0d01}, {0x0d3b, 0x0d3c}, {0x0d41, 0x0d44}, + {0x0d4d, 0x0d4d}, {0x0d62, 0x0d63}, {0x0dca, 0x0dca}, + {0x0dd2, 0x0dd4}, {0x0dd6, 0x0dd6}, {0x0e31, 0x0e31}, + {0x0e34, 0x0e3a}, {0x0e47, 0x0e4e}, {0x0eb1, 0x0eb1}, + {0x0eb4, 0x0eb9}, {0x0ebb, 0x0ebc}, {0x0ec8, 0x0ecd}, + {0x0f18, 0x0f19}, {0x0f35, 0x0f35}, {0x0f37, 0x0f37}, + {0x0f39, 0x0f39}, {0x0f71, 0x0f7e}, {0x0f80, 0x0f84}, + {0x0f86, 0x0f87}, {0x0f8d, 0x0f97}, {0x0f99, 0x0fbc}, + {0x0fc6, 0x0fc6}, {0x102d, 0x1030}, {0x1032, 0x1037}, + {0x1039, 0x103a}, {0x103d, 0x103e}, {0x1058, 0x1059}, + {0x105e, 0x1060}, {0x1071, 0x1074}, {0x1082, 0x1082}, + {0x1085, 0x1086}, {0x108d, 0x108d}, {0x109d, 0x109d}, + {0x1160, 0x11ff}, {0x135d, 0x135f}, {0x1712, 0x1714}, + {0x1732, 0x1734}, {0x1752, 0x1753}, {0x1772, 0x1773}, + {0x17b4, 0x17b5}, {0x17b7, 0x17bd}, {0x17c6, 0x17c6}, + {0x17c9, 0x17d3}, {0x17dd, 0x17dd}, {0x180b, 0x180e}, + {0x1885, 0x1886}, {0x18a9, 0x18a9}, {0x1920, 0x1922}, + {0x1927, 0x1928}, {0x1932, 0x1932}, {0x1939, 0x193b}, + {0x1a17, 0x1a18}, {0x1a1b, 0x1a1b}, {0x1a56, 0x1a56}, + {0x1a58, 0x1a5e}, {0x1a60, 0x1a60}, {0x1a62, 0x1a62}, + {0x1a65, 0x1a6c}, {0x1a73, 0x1a7c}, {0x1a7f, 0x1a7f}, + {0x1ab0, 0x1abe}, {0x1b00, 0x1b03}, {0x1b34, 0x1b34}, + {0x1b36, 0x1b3a}, {0x1b3c, 0x1b3c}, {0x1b42, 0x1b42}, + {0x1b6b, 0x1b73}, {0x1b80, 0x1b81}, {0x1ba2, 0x1ba5}, + {0x1ba8, 0x1ba9}, {0x1bab, 0x1bad}, {0x1be6, 0x1be6}, + {0x1be8, 0x1be9}, {0x1bed, 0x1bed}, {0x1bef, 0x1bf1}, + {0x1c2c, 0x1c33}, {0x1c36, 0x1c37}, {0x1cd0, 0x1cd2}, + {0x1cd4, 0x1ce0}, {0x1ce2, 0x1ce8}, {0x1ced, 0x1ced}, + {0x1cf4, 0x1cf4}, {0x1cf8, 0x1cf9}, {0x1dc0, 0x1df9}, + {0x1dfb, 0x1dff}, {0x200b, 0x200f}, {0x202a, 0x202e}, + {0x2060, 0x2064}, {0x2066, 0x206f}, {0x20d0, 0x20f0}, + {0x2cef, 0x2cf1}, {0x2d7f, 0x2d7f}, {0x2de0, 0x2dff}, + {0x302a, 0x302d}, {0x3099, 0x309a}, {0xa66f, 0xa672}, + {0xa674, 0xa67d}, {0xa69e, 0xa69f}, {0xa6f0, 0xa6f1}, + {0xa802, 0xa802}, {0xa806, 0xa806}, {0xa80b, 0xa80b}, + {0xa825, 0xa826}, {0xa8c4, 0xa8c5}, {0xa8e0, 0xa8f1}, + {0xa926, 0xa92d}, {0xa947, 0xa951}, {0xa980, 0xa982}, + {0xa9b3, 0xa9b3}, {0xa9b6, 0xa9b9}, {0xa9bc, 0xa9bc}, + {0xa9e5, 0xa9e5}, {0xaa29, 0xaa2e}, {0xaa31, 0xaa32}, + {0xaa35, 0xaa36}, {0xaa43, 0xaa43}, {0xaa4c, 0xaa4c}, + {0xaa7c, 0xaa7c}, {0xaab0, 0xaab0}, {0xaab2, 0xaab4}, + {0xaab7, 0xaab8}, {0xaabe, 0xaabf}, {0xaac1, 0xaac1}, + {0xaaec, 0xaaed}, {0xaaf6, 0xaaf6}, {0xabe5, 0xabe5}, + {0xabe8, 0xabe8}, {0xabed, 0xabed}, {0xfb1e, 0xfb1e}, + {0xfe00, 0xfe0f}, {0xfe20, 0xfe2f}, {0xfeff, 0xfeff}, + {0xfff9, 0xfffb}, {0x101fd, 0x101fd}, {0x102e0, 0x102e0}, + {0x10376, 0x1037a}, {0x10a01, 0x10a03}, {0x10a05, 0x10a06}, + {0x10a0c, 0x10a0f}, {0x10a38, 0x10a3a}, {0x10a3f, 0x10a3f}, + {0x10ae5, 0x10ae6}, {0x11001, 0x11001}, {0x11038, 0x11046}, + {0x1107f, 0x11081}, {0x110b3, 0x110b6}, {0x110b9, 0x110ba}, + {0x11100, 0x11102}, {0x11127, 0x1112b}, {0x1112d, 0x11134}, + {0x11173, 0x11173}, {0x11180, 0x11181}, {0x111b6, 0x111be}, + {0x111ca, 0x111cc}, {0x1122f, 0x11231}, {0x11234, 0x11234}, + {0x11236, 0x11237}, {0x1123e, 0x1123e}, {0x112df, 0x112df}, + {0x112e3, 0x112ea}, {0x11300, 0x11301}, {0x1133c, 0x1133c}, + {0x11340, 0x11340}, {0x11366, 0x1136c}, {0x11370, 0x11374}, + {0x11438, 0x1143f}, {0x11442, 0x11444}, {0x11446, 0x11446}, + {0x114b3, 0x114b8}, {0x114ba, 0x114ba}, {0x114bf, 0x114c0}, + {0x114c2, 0x114c3}, {0x115b2, 0x115b5}, {0x115bc, 0x115bd}, + {0x115bf, 0x115c0}, {0x115dc, 0x115dd}, {0x11633, 0x1163a}, + {0x1163d, 0x1163d}, {0x1163f, 0x11640}, {0x116ab, 0x116ab}, + {0x116ad, 0x116ad}, {0x116b0, 0x116b5}, {0x116b7, 0x116b7}, + {0x1171d, 0x1171f}, {0x11722, 0x11725}, {0x11727, 0x1172b}, + {0x11a01, 0x11a06}, {0x11a09, 0x11a0a}, {0x11a33, 0x11a38}, + {0x11a3b, 0x11a3e}, {0x11a47, 0x11a47}, {0x11a51, 0x11a56}, + {0x11a59, 0x11a5b}, {0x11a8a, 0x11a96}, {0x11a98, 0x11a99}, + {0x11c30, 0x11c36}, {0x11c38, 0x11c3d}, {0x11c3f, 0x11c3f}, + {0x11c92, 0x11ca7}, {0x11caa, 0x11cb0}, {0x11cb2, 0x11cb3}, + {0x11cb5, 0x11cb6}, {0x11d31, 0x11d36}, {0x11d3a, 0x11d3a}, + {0x11d3c, 0x11d3d}, {0x11d3f, 0x11d45}, {0x11d47, 0x11d47}, + {0x16af0, 0x16af4}, {0x16b30, 0x16b36}, {0x16f8f, 0x16f92}, + {0x1bc9d, 0x1bc9e}, {0x1bca0, 0x1bca3}, {0x1d167, 0x1d169}, + {0x1d173, 0x1d182}, {0x1d185, 0x1d18b}, {0x1d1aa, 0x1d1ad}, + {0x1d242, 0x1d244}, {0x1da00, 0x1da36}, {0x1da3b, 0x1da6c}, + {0x1da75, 0x1da75}, {0x1da84, 0x1da84}, {0x1da9b, 0x1da9f}, + {0x1daa1, 0x1daaf}, {0x1e000, 0x1e006}, {0x1e008, 0x1e018}, + {0x1e01b, 0x1e021}, {0x1e023, 0x1e024}, {0x1e026, 0x1e02a}, + {0x1e8d0, 0x1e8d6}, {0x1e944, 0x1e94a}, {0xe0001, 0xe0001}, + {0xe0020, 0xe007f}, {0xe0100, 0xe01ef}, + }; + + /* test for 8-bit control characters */ + if ( ucs == 0 ) { + return 0; + } + if ( ( ucs < 32 ) || ( ( ucs >= 0x7f ) && ( ucs < 0xa0 ) ) ) { + return -1; + } + + /* binary search in table of non-spacing characters */ + if ( bisearch( ucs, combining, sizeof( combining ) / sizeof( struct interval ) - 1 ) ) { + return 0; + } + + /* if we arrive here, ucs is not a combining or C0/C1 control character */ + return ( mk_is_wide_char( ucs ) ? 2 : 1 ); +} + diff --git a/extern/linenoise.hpp b/extern/linenoise.hpp deleted file mode 100644 index ae36eb0b2..000000000 --- a/extern/linenoise.hpp +++ /dev/null @@ -1,2415 +0,0 @@ -/* - * linenoise.hpp -- Multi-platfrom C++ header-only linenoise library. - * - * All credits and commendations have to go to the authors of the - * following excellent libraries. - * - * - linenoise.h and linenose.c (https://github.com/antirez/linenoise) - * - ANSI.c (https://github.com/adoxa/ansicon) - * - Win32_ANSI.h and Win32_ANSI.c (https://github.com/MSOpenTech/redis) - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2015 yhirose - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - * 1. Redistributions of source code must retain the above copyright notice, this - * list of conditions and the following disclaimer. - * 2. Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND - * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED - * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR - * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES - * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; - * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND - * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS - * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -/* linenoise.h -- guerrilla line editing library against the idea that a - * line editing lib needs to be 20,000 lines of C code. - * - * See linenoise.c for more information. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010, Salvatore Sanfilippo - * Copyright (c) 2010, Pieter Noordhuis - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -/* - * ANSI.c - ANSI escape sequence console driver. - * - * Copyright (C) 2005-2014 Jason Hood - * This software is provided 'as-is', without any express or implied - * warranty. In no event will the author be held liable for any damages - * arising from the use of this software. - * - * Permission is granted to anyone to use this software for any purpose, - * including commercial applications, and to alter it and redistribute it - * freely, subject to the following restrictions: - * - * 1. The origin of this software must not be misrepresented; you must not - * claim that you wrote the original software. If you use this software - * in a product, an acknowledgment in the product documentation would be - * appreciated but is not required. - * 2. Altered source versions must be plainly marked as such, and must not be - * misrepresented as being the original software. - * 3. This notice may not be removed or altered from any source distribution. - * - * Jason Hood - * jadoxa@yahoo.com.au - */ - -/* - * Win32_ANSI.h and Win32_ANSI.c - * - * Derived from ANSI.c by Jason Hood, from his ansicon project (https://github.com/adoxa/ansicon), with modifications. - * - * Copyright (c), Microsoft Open Technologies, Inc. - * All rights reserved. - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are met: - * - Redistributions of source code must retain the above copyright notice, - * this list of conditions and the following disclaimer. - * - Redistributions in binary form must reproduce the above copyright notice, - * this list of conditions and the following disclaimer in the documentation - * and/or other materials provided with the distribution. - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" - * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE - * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE - * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE - * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL - * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR - * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER - * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, - * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef LINENOISE_HPP -#define LINENOISE_HPP - -#ifndef _WIN32 -#include -#include -#include -#else -#ifndef NOMINMAX -#define NOMINMAX -#endif -#include -#include -#ifndef STDIN_FILENO -#define STDIN_FILENO (_fileno(stdin)) -#endif -#ifndef STDOUT_FILENO -#define STDOUT_FILENO 1 -#endif -#define isatty _isatty -#define write win32_write -#define read _read -#pragma warning(push) -#pragma warning(disable : 4996) -#endif -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -namespace linenoise { - -typedef std::function&)> CompletionCallback; - -#ifdef _WIN32 - -namespace ansi { - -#define lenof(array) (sizeof(array)/sizeof(*(array))) - -typedef struct -{ - BYTE foreground; // ANSI base color (0 to 7; add 30) - BYTE background; // ANSI base color (0 to 7; add 40) - BYTE bold; // console FOREGROUND_INTENSITY bit - BYTE underline; // console BACKGROUND_INTENSITY bit - BYTE rvideo; // swap foreground/bold & background/underline - BYTE concealed; // set foreground/bold to background/underline - BYTE reverse; // swap console foreground & background attributes -} GRM, *PGRM; // Graphic Rendition Mode - - -inline bool is_digit(char c) { return '0' <= c && c <= '9'; } - -// ========== Global variables and constants - -HANDLE hConOut; // handle to CONOUT$ - -const char ESC = '\x1B'; // ESCape character -const char BEL = '\x07'; -const char SO = '\x0E'; // Shift Out -const char SI = '\x0F'; // Shift In - -const int MAX_ARG = 16; // max number of args in an escape sequence -int state; // automata state -WCHAR prefix; // escape sequence prefix ( '[', ']' or '(' ); -WCHAR prefix2; // secondary prefix ( '?' or '>' ); -WCHAR suffix; // escape sequence suffix -int es_argc; // escape sequence args count -int es_argv[MAX_ARG]; // escape sequence args -WCHAR Pt_arg[MAX_PATH * 2]; // text parameter for Operating System Command -int Pt_len; -BOOL shifted; - - -// DEC Special Graphics Character Set from -// http://vt100.net/docs/vt220-rm/table2-4.html -// Some of these may not look right, depending on the font and code page (in -// particular, the Control Pictures probably won't work at all). -const WCHAR G1[] = -{ - ' ', // _ - blank - L'\x2666', // ` - Black Diamond Suit - L'\x2592', // a - Medium Shade - L'\x2409', // b - HT - L'\x240c', // c - FF - L'\x240d', // d - CR - L'\x240a', // e - LF - L'\x00b0', // f - Degree Sign - L'\x00b1', // g - Plus-Minus Sign - L'\x2424', // h - NL - L'\x240b', // i - VT - L'\x2518', // j - Box Drawings Light Up And Left - L'\x2510', // k - Box Drawings Light Down And Left - L'\x250c', // l - Box Drawings Light Down And Right - L'\x2514', // m - Box Drawings Light Up And Right - L'\x253c', // n - Box Drawings Light Vertical And Horizontal - L'\x00af', // o - SCAN 1 - Macron - L'\x25ac', // p - SCAN 3 - Black Rectangle - L'\x2500', // q - SCAN 5 - Box Drawings Light Horizontal - L'_', // r - SCAN 7 - Low Line - L'_', // s - SCAN 9 - Low Line - L'\x251c', // t - Box Drawings Light Vertical And Right - L'\x2524', // u - Box Drawings Light Vertical And Left - L'\x2534', // v - Box Drawings Light Up And Horizontal - L'\x252c', // w - Box Drawings Light Down And Horizontal - L'\x2502', // x - Box Drawings Light Vertical - L'\x2264', // y - Less-Than Or Equal To - L'\x2265', // z - Greater-Than Or Equal To - L'\x03c0', // { - Greek Small Letter Pi - L'\x2260', // | - Not Equal To - L'\x00a3', // } - Pound Sign - L'\x00b7', // ~ - Middle Dot -}; - -#define FIRST_G1 '_' -#define LAST_G1 '~' - - -// color constants - -#define FOREGROUND_BLACK 0 -#define FOREGROUND_WHITE FOREGROUND_RED|FOREGROUND_GREEN|FOREGROUND_BLUE - -#define BACKGROUND_BLACK 0 -#define BACKGROUND_WHITE BACKGROUND_RED|BACKGROUND_GREEN|BACKGROUND_BLUE - -const BYTE foregroundcolor[8] = - { - FOREGROUND_BLACK, // black foreground - FOREGROUND_RED, // red foreground - FOREGROUND_GREEN, // green foreground - FOREGROUND_RED | FOREGROUND_GREEN, // yellow foreground - FOREGROUND_BLUE, // blue foreground - FOREGROUND_BLUE | FOREGROUND_RED, // magenta foreground - FOREGROUND_BLUE | FOREGROUND_GREEN, // cyan foreground - FOREGROUND_WHITE // white foreground - }; - -const BYTE backgroundcolor[8] = - { - BACKGROUND_BLACK, // black background - BACKGROUND_RED, // red background - BACKGROUND_GREEN, // green background - BACKGROUND_RED | BACKGROUND_GREEN, // yellow background - BACKGROUND_BLUE, // blue background - BACKGROUND_BLUE | BACKGROUND_RED, // magenta background - BACKGROUND_BLUE | BACKGROUND_GREEN, // cyan background - BACKGROUND_WHITE, // white background - }; - -const BYTE attr2ansi[8] = // map console attribute to ANSI number -{ - 0, // black - 4, // blue - 2, // green - 6, // cyan - 1, // red - 5, // magenta - 3, // yellow - 7 // white -}; - -GRM grm; - -// saved cursor position -COORD SavePos; - -// ========== Print Buffer functions - -#define BUFFER_SIZE 2048 - -int nCharInBuffer; -WCHAR ChBuffer[BUFFER_SIZE]; - -//----------------------------------------------------------------------------- -// FlushBuffer() -// Writes the buffer to the console and empties it. -//----------------------------------------------------------------------------- - -inline void FlushBuffer(void) -{ - DWORD nWritten; - if (nCharInBuffer <= 0) return; - WriteConsoleW(hConOut, ChBuffer, nCharInBuffer, &nWritten, NULL); - nCharInBuffer = 0; -} - -//----------------------------------------------------------------------------- -// PushBuffer( WCHAR c ) -// Adds a character in the buffer. -//----------------------------------------------------------------------------- - -inline void PushBuffer(WCHAR c) -{ - if (shifted && c >= FIRST_G1 && c <= LAST_G1) - c = G1[c - FIRST_G1]; - ChBuffer[nCharInBuffer] = c; - if (++nCharInBuffer == BUFFER_SIZE) - FlushBuffer(); -} - -//----------------------------------------------------------------------------- -// SendSequence( LPCWSTR seq ) -// Send the string to the input buffer. -//----------------------------------------------------------------------------- - -inline void SendSequence(LPCWSTR seq) -{ - DWORD out; - INPUT_RECORD in; - HANDLE hStdIn = GetStdHandle(STD_INPUT_HANDLE); - - in.EventType = KEY_EVENT; - in.Event.KeyEvent.bKeyDown = TRUE; - in.Event.KeyEvent.wRepeatCount = 1; - in.Event.KeyEvent.wVirtualKeyCode = 0; - in.Event.KeyEvent.wVirtualScanCode = 0; - in.Event.KeyEvent.dwControlKeyState = 0; - for (; *seq; ++seq) - { - in.Event.KeyEvent.uChar.UnicodeChar = *seq; - WriteConsoleInput(hStdIn, &in, 1, &out); - } -} - -// ========== Print functions - -//----------------------------------------------------------------------------- -// InterpretEscSeq() -// Interprets the last escape sequence scanned by ParseAndPrintANSIString -// prefix escape sequence prefix -// es_argc escape sequence args count -// es_argv[] escape sequence args array -// suffix escape sequence suffix -// -// for instance, with \e[33;45;1m we have -// prefix = '[', -// es_argc = 3, es_argv[0] = 33, es_argv[1] = 45, es_argv[2] = 1 -// suffix = 'm' -//----------------------------------------------------------------------------- - -inline void InterpretEscSeq(void) -{ - int i; - WORD attribut; - CONSOLE_SCREEN_BUFFER_INFO Info; - CONSOLE_CURSOR_INFO CursInfo; - DWORD len, NumberOfCharsWritten; - COORD Pos; - SMALL_RECT Rect; - CHAR_INFO CharInfo; - - if (prefix == '[') - { - if (prefix2 == '?' && (suffix == 'h' || suffix == 'l')) - { - if (es_argc == 1 && es_argv[0] == 25) - { - GetConsoleCursorInfo(hConOut, &CursInfo); - CursInfo.bVisible = (suffix == 'h'); - SetConsoleCursorInfo(hConOut, &CursInfo); - return; - } - } - // Ignore any other \e[? or \e[> sequences. - if (prefix2 != 0) - return; - - GetConsoleScreenBufferInfo(hConOut, &Info); - switch (suffix) - { - case 'm': - if (es_argc == 0) es_argv[es_argc++] = 0; - for (i = 0; i < es_argc; i++) - { - if (30 <= es_argv[i] && es_argv[i] <= 37) - grm.foreground = es_argv[i] - 30; - else if (40 <= es_argv[i] && es_argv[i] <= 47) - grm.background = es_argv[i] - 40; - else switch (es_argv[i]) - { - case 0: - case 39: - case 49: - { - WCHAR def[4]; - int a; - *def = '7'; def[1] = '\0'; - GetEnvironmentVariableW(L"ANSICON_DEF", def, lenof(def)); - a = wcstol(def, NULL, 16); - grm.reverse = FALSE; - if (a < 0) - { - grm.reverse = TRUE; - a = -a; - } - if (es_argv[i] != 49) - grm.foreground = attr2ansi[a & 7]; - if (es_argv[i] != 39) - grm.background = attr2ansi[(a >> 4) & 7]; - if (es_argv[i] == 0) - { - if (es_argc == 1) - { - grm.bold = a & FOREGROUND_INTENSITY; - grm.underline = a & BACKGROUND_INTENSITY; - } - else - { - grm.bold = 0; - grm.underline = 0; - } - grm.rvideo = 0; - grm.concealed = 0; - } - } - break; - - case 1: grm.bold = FOREGROUND_INTENSITY; break; - case 5: // blink - case 4: grm.underline = BACKGROUND_INTENSITY; break; - case 7: grm.rvideo = 1; break; - case 8: grm.concealed = 1; break; - case 21: // oops, this actually turns on double underline - case 22: grm.bold = 0; break; - case 25: - case 24: grm.underline = 0; break; - case 27: grm.rvideo = 0; break; - case 28: grm.concealed = 0; break; - } - } - if (grm.concealed) - { - if (grm.rvideo) - { - attribut = foregroundcolor[grm.foreground] - | backgroundcolor[grm.foreground]; - if (grm.bold) - attribut |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; - } - else - { - attribut = foregroundcolor[grm.background] - | backgroundcolor[grm.background]; - if (grm.underline) - attribut |= FOREGROUND_INTENSITY | BACKGROUND_INTENSITY; - } - } - else if (grm.rvideo) - { - attribut = foregroundcolor[grm.background] - | backgroundcolor[grm.foreground]; - if (grm.bold) - attribut |= BACKGROUND_INTENSITY; - if (grm.underline) - attribut |= FOREGROUND_INTENSITY; - } - else - attribut = foregroundcolor[grm.foreground] | grm.bold - | backgroundcolor[grm.background] | grm.underline; - if (grm.reverse) - attribut = ((attribut >> 4) & 15) | ((attribut & 15) << 4); - SetConsoleTextAttribute(hConOut, attribut); - return; - - case 'J': - if (es_argc == 0) es_argv[es_argc++] = 0; // ESC[J == ESC[0J - if (es_argc != 1) return; - switch (es_argv[0]) - { - case 0: // ESC[0J erase from cursor to end of display - len = (Info.dwSize.Y - Info.dwCursorPosition.Y - 1) * Info.dwSize.X - + Info.dwSize.X - Info.dwCursorPosition.X - 1; - FillConsoleOutputCharacter(hConOut, ' ', len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - return; - - case 1: // ESC[1J erase from start to cursor. - Pos.X = 0; - Pos.Y = 0; - len = Info.dwCursorPosition.Y * Info.dwSize.X - + Info.dwCursorPosition.X + 1; - FillConsoleOutputCharacter(hConOut, ' ', len, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, Pos, - &NumberOfCharsWritten); - return; - - case 2: // ESC[2J Clear screen and home cursor - Pos.X = 0; - Pos.Y = 0; - len = Info.dwSize.X * Info.dwSize.Y; - FillConsoleOutputCharacter(hConOut, ' ', len, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, Pos, - &NumberOfCharsWritten); - SetConsoleCursorPosition(hConOut, Pos); - return; - - default: - return; - } - - case 'K': - if (es_argc == 0) es_argv[es_argc++] = 0; // ESC[K == ESC[0K - if (es_argc != 1) return; - switch (es_argv[0]) - { - case 0: // ESC[0K Clear to end of line - len = Info.dwSize.X - Info.dwCursorPosition.X + 1; - FillConsoleOutputCharacter(hConOut, ' ', len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, len, - Info.dwCursorPosition, - &NumberOfCharsWritten); - return; - - case 1: // ESC[1K Clear from start of line to cursor - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - FillConsoleOutputCharacter(hConOut, ' ', - Info.dwCursorPosition.X + 1, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, - Info.dwCursorPosition.X + 1, Pos, - &NumberOfCharsWritten); - return; - - case 2: // ESC[2K Clear whole line. - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - FillConsoleOutputCharacter(hConOut, ' ', Info.dwSize.X, Pos, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, - Info.dwSize.X, Pos, - &NumberOfCharsWritten); - return; - - default: - return; - } - - case 'X': // ESC[#X Erase # characters. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[X == ESC[1X - if (es_argc != 1) return; - FillConsoleOutputCharacter(hConOut, ' ', es_argv[0], - Info.dwCursorPosition, - &NumberOfCharsWritten); - FillConsoleOutputAttribute(hConOut, Info.wAttributes, es_argv[0], - Info.dwCursorPosition, - &NumberOfCharsWritten); - return; - - case 'L': // ESC[#L Insert # blank lines. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[L == ESC[1L - if (es_argc != 1) return; - Rect.Left = 0; - Rect.Top = Info.dwCursorPosition.Y; - Rect.Right = Info.dwSize.X - 1; - Rect.Bottom = Info.dwSize.Y - 1; - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); - return; - - case 'M': // ESC[#M Delete # lines. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[M == ESC[1M - if (es_argc != 1) return; - if (es_argv[0] > Info.dwSize.Y - Info.dwCursorPosition.Y) - es_argv[0] = Info.dwSize.Y - Info.dwCursorPosition.Y; - Rect.Left = 0; - Rect.Top = Info.dwCursorPosition.Y + es_argv[0]; - Rect.Right = Info.dwSize.X - 1; - Rect.Bottom = Info.dwSize.Y - 1; - Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); - return; - - case 'P': // ESC[#P Delete # characters. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[P == ESC[1P - if (es_argc != 1) return; - if (Info.dwCursorPosition.X + es_argv[0] > Info.dwSize.X - 1) - es_argv[0] = Info.dwSize.X - Info.dwCursorPosition.X; - Rect.Left = Info.dwCursorPosition.X + es_argv[0]; - Rect.Top = Info.dwCursorPosition.Y; - Rect.Right = Info.dwSize.X - 1; - Rect.Bottom = Info.dwCursorPosition.Y; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Info.dwCursorPosition, - &CharInfo); - return; - - case '@': // ESC[#@ Insert # blank characters. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[@ == ESC[1@ - if (es_argc != 1) return; - if (Info.dwCursorPosition.X + es_argv[0] > Info.dwSize.X - 1) - es_argv[0] = Info.dwSize.X - Info.dwCursorPosition.X; - Rect.Left = Info.dwCursorPosition.X; - Rect.Top = Info.dwCursorPosition.Y; - Rect.Right = Info.dwSize.X - 1 - es_argv[0]; - Rect.Bottom = Info.dwCursorPosition.Y; - Pos.X = Info.dwCursorPosition.X + es_argv[0]; - Pos.Y = Info.dwCursorPosition.Y; - CharInfo.Char.UnicodeChar = ' '; - CharInfo.Attributes = Info.wAttributes; - ScrollConsoleScreenBuffer(hConOut, &Rect, NULL, Pos, &CharInfo); - return; - - case 'k': // ESC[#k - case 'A': // ESC[#A Moves cursor up # lines - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[A == ESC[1A - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y - es_argv[0]; - if (Pos.Y < 0) Pos.Y = 0; - Pos.X = Info.dwCursorPosition.X; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'e': // ESC[#e - case 'B': // ESC[#B Moves cursor down # lines - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[B == ESC[1B - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - Pos.X = Info.dwCursorPosition.X; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'a': // ESC[#a - case 'C': // ESC[#C Moves cursor forward # spaces - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[C == ESC[1C - if (es_argc != 1) return; - Pos.X = Info.dwCursorPosition.X + es_argv[0]; - if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; - Pos.Y = Info.dwCursorPosition.Y; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'j': // ESC[#j - case 'D': // ESC[#D Moves cursor back # spaces - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[D == ESC[1D - if (es_argc != 1) return; - Pos.X = Info.dwCursorPosition.X - es_argv[0]; - if (Pos.X < 0) Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'E': // ESC[#E Moves cursor down # lines, column 1. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[E == ESC[1E - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y + es_argv[0]; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - Pos.X = 0; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'F': // ESC[#F Moves cursor up # lines, column 1. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[F == ESC[1F - if (es_argc != 1) return; - Pos.Y = Info.dwCursorPosition.Y - es_argv[0]; - if (Pos.Y < 0) Pos.Y = 0; - Pos.X = 0; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case '`': // ESC[#` - case 'G': // ESC[#G Moves cursor column # in current row. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[G == ESC[1G - if (es_argc != 1) return; - Pos.X = es_argv[0] - 1; - if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; - if (Pos.X < 0) Pos.X = 0; - Pos.Y = Info.dwCursorPosition.Y; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'd': // ESC[#d Moves cursor row #, current column. - if (es_argc == 0) es_argv[es_argc++] = 1; // ESC[d == ESC[1d - if (es_argc != 1) return; - Pos.Y = es_argv[0] - 1; - if (Pos.Y < 0) Pos.Y = 0; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 'f': // ESC[#;#f - case 'H': // ESC[#;#H Moves cursor to line #, column # - if (es_argc == 0) - es_argv[es_argc++] = 1; // ESC[H == ESC[1;1H - if (es_argc == 1) - es_argv[es_argc++] = 1; // ESC[#H == ESC[#;1H - if (es_argc > 2) return; - Pos.X = es_argv[1] - 1; - if (Pos.X < 0) Pos.X = 0; - if (Pos.X >= Info.dwSize.X) Pos.X = Info.dwSize.X - 1; - Pos.Y = es_argv[0] - 1; - if (Pos.Y < 0) Pos.Y = 0; - if (Pos.Y >= Info.dwSize.Y) Pos.Y = Info.dwSize.Y - 1; - SetConsoleCursorPosition(hConOut, Pos); - return; - - case 's': // ESC[s Saves cursor position for recall later - if (es_argc != 0) return; - SavePos = Info.dwCursorPosition; - return; - - case 'u': // ESC[u Return to saved cursor position - if (es_argc != 0) return; - SetConsoleCursorPosition(hConOut, SavePos); - return; - - case 'n': // ESC[#n Device status report - if (es_argc != 1) return; // ESC[n == ESC[0n -> ignored - switch (es_argv[0]) - { - case 5: // ESC[5n Report status - SendSequence(L"\33[0n"); // "OK" - return; - - case 6: // ESC[6n Report cursor position - { - WCHAR buf[32]; - swprintf(buf, 32, L"\33[%d;%dR", Info.dwCursorPosition.Y + 1, - Info.dwCursorPosition.X + 1); - SendSequence(buf); - } - return; - - default: - return; - } - - case 't': // ESC[#t Window manipulation - if (es_argc != 1) return; - if (es_argv[0] == 21) // ESC[21t Report xterm window's title - { - WCHAR buf[MAX_PATH * 2]; - DWORD len = GetConsoleTitleW(buf + 3, lenof(buf) - 3 - 2); - // Too bad if it's too big or fails. - buf[0] = ESC; - buf[1] = ']'; - buf[2] = 'l'; - buf[3 + len] = ESC; - buf[3 + len + 1] = '\\'; - buf[3 + len + 2] = '\0'; - SendSequence(buf); - } - return; - - default: - return; - } - } - else // (prefix == ']') - { - // Ignore any \e]? or \e]> sequences. - if (prefix2 != 0) - return; - - if (es_argc == 1 && es_argv[0] == 0) // ESC]0;titleST - { - SetConsoleTitleW(Pt_arg); - } - } -} - -//----------------------------------------------------------------------------- -// ParseAndPrintANSIString(hDev, lpBuffer, nNumberOfBytesToWrite) -// Parses the string lpBuffer, interprets the escapes sequences and prints the -// characters in the device hDev (console). -// The lexer is a three states automata. -// If the number of arguments es_argc > MAX_ARG, only the MAX_ARG-1 firsts and -// the last arguments are processed (no es_argv[] overflow). -//----------------------------------------------------------------------------- - -inline BOOL ParseAndPrintANSIString(HANDLE hDev, LPCVOID lpBuffer, DWORD nNumberOfBytesToWrite, LPDWORD lpNumberOfBytesWritten) -{ - DWORD i; - LPCSTR s; - - if (hDev != hConOut) // reinit if device has changed - { - hConOut = hDev; - state = 1; - shifted = FALSE; - } - for (i = nNumberOfBytesToWrite, s = (LPCSTR)lpBuffer; i > 0; i--, s++) - { - if (state == 1) - { - if (*s == ESC) state = 2; - else if (*s == SO) shifted = TRUE; - else if (*s == SI) shifted = FALSE; - else PushBuffer(*s); - } - else if (state == 2) - { - if (*s == ESC); // \e\e...\e == \e - else if ((*s == '[') || (*s == ']')) - { - FlushBuffer(); - prefix = *s; - prefix2 = 0; - state = 3; - Pt_len = 0; - *Pt_arg = '\0'; - } - else if (*s == ')' || *s == '(') state = 6; - else state = 1; - } - else if (state == 3) - { - if (is_digit(*s)) - { - es_argc = 0; - es_argv[0] = *s - '0'; - state = 4; - } - else if (*s == ';') - { - es_argc = 1; - es_argv[0] = 0; - es_argv[1] = 0; - state = 4; - } - else if (*s == '?' || *s == '>') - { - prefix2 = *s; - } - else - { - es_argc = 0; - suffix = *s; - InterpretEscSeq(); - state = 1; - } - } - else if (state == 4) - { - if (is_digit(*s)) - { - es_argv[es_argc] = 10 * es_argv[es_argc] + (*s - '0'); - } - else if (*s == ';') - { - if (es_argc < MAX_ARG - 1) es_argc++; - es_argv[es_argc] = 0; - if (prefix == ']') - state = 5; - } - else - { - es_argc++; - suffix = *s; - InterpretEscSeq(); - state = 1; - } - } - else if (state == 5) - { - if (*s == BEL) - { - Pt_arg[Pt_len] = '\0'; - InterpretEscSeq(); - state = 1; - } - else if (*s == '\\' && Pt_len > 0 && Pt_arg[Pt_len - 1] == ESC) - { - Pt_arg[--Pt_len] = '\0'; - InterpretEscSeq(); - state = 1; - } - else if (Pt_len < lenof(Pt_arg) - 1) - Pt_arg[Pt_len++] = *s; - } - else if (state == 6) - { - // Ignore it (ESC ) 0 is implicit; nothing else is supported). - state = 1; - } - } - FlushBuffer(); - if (lpNumberOfBytesWritten != NULL) - *lpNumberOfBytesWritten = nNumberOfBytesToWrite - i; - return (i == 0); -} - -} // namespace ansi - -HANDLE hOut; -HANDLE hIn; -DWORD consolemodeIn = 0; - -inline int win32read(int *c) { - DWORD foo; - INPUT_RECORD b; - KEY_EVENT_RECORD e; - BOOL altgr; - - while (1) { - if (!ReadConsoleInput(hIn, &b, 1, &foo)) return 0; - if (!foo) return 0; - - if (b.EventType == KEY_EVENT && b.Event.KeyEvent.bKeyDown) { - - e = b.Event.KeyEvent; - *c = b.Event.KeyEvent.uChar.AsciiChar; - - altgr = e.dwControlKeyState & (LEFT_CTRL_PRESSED | RIGHT_ALT_PRESSED); - - if (e.dwControlKeyState & (LEFT_CTRL_PRESSED | RIGHT_CTRL_PRESSED) && !altgr) { - - /* Ctrl+Key */ - switch (*c) { - case 'D': - *c = 4; - return 1; - case 'C': - *c = 3; - return 1; - case 'H': - *c = 8; - return 1; - case 'T': - *c = 20; - return 1; - case 'B': /* ctrl-b, left_arrow */ - *c = 2; - return 1; - case 'F': /* ctrl-f right_arrow*/ - *c = 6; - return 1; - case 'P': /* ctrl-p up_arrow*/ - *c = 16; - return 1; - case 'N': /* ctrl-n down_arrow*/ - *c = 14; - return 1; - case 'U': /* Ctrl+u, delete the whole line. */ - *c = 21; - return 1; - case 'K': /* Ctrl+k, delete from current to end of line. */ - *c = 11; - return 1; - case 'A': /* Ctrl+a, go to the start of the line */ - *c = 1; - return 1; - case 'E': /* ctrl+e, go to the end of the line */ - *c = 5; - return 1; - } - - /* Other Ctrl+KEYs ignored */ - } else { - - switch (e.wVirtualKeyCode) { - - case VK_ESCAPE: /* ignore - send ctrl-c, will return -1 */ - *c = 3; - return 1; - case VK_RETURN: /* enter */ - *c = 13; - return 1; - case VK_LEFT: /* left */ - *c = 2; - return 1; - case VK_RIGHT: /* right */ - *c = 6; - return 1; - case VK_UP: /* up */ - *c = 16; - return 1; - case VK_DOWN: /* down */ - *c = 14; - return 1; - case VK_HOME: - *c = 1; - return 1; - case VK_END: - *c = 5; - return 1; - case VK_BACK: - *c = 8; - return 1; - case VK_DELETE: - *c = 4; /* same as Ctrl+D above */ - return 1; - default: - if (*c) return 1; - } - } - } - } - - return -1; /* Makes compiler happy */ -} - -inline int win32_write(int fd, const void *buffer, unsigned int count) { - if (fd == _fileno(stdout)) { - DWORD bytesWritten = 0; - if (FALSE != ansi::ParseAndPrintANSIString(GetStdHandle(STD_OUTPUT_HANDLE), buffer, (DWORD)count, &bytesWritten)) { - return (int)bytesWritten; - } else { - errno = GetLastError(); - return 0; - } - } else if (fd == _fileno(stderr)) { - DWORD bytesWritten = 0; - if (FALSE != ansi::ParseAndPrintANSIString(GetStdHandle(STD_ERROR_HANDLE), buffer, (DWORD)count, &bytesWritten)) { - return (int)bytesWritten; - } else { - errno = GetLastError(); - return 0; - } - } else { - return _write(fd, buffer, count); - } -} -#endif // _WIN32 - -#define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 -#define LINENOISE_MAX_LINE 4096 -static const char *unsupported_term[] = {"dumb","cons25","emacs",NULL}; -static CompletionCallback completionCallback; - -#ifndef _WIN32 -static struct termios orig_termios; /* In order to restore at exit.*/ -#endif -static bool rawmode = false; /* For atexit() function to check if restore is needed*/ -static bool mlmode = false; /* Multi line mode. Default is single line. */ -static bool atexit_registered = false; /* Register atexit just 1 time. */ -static size_t history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; -static std::vector history; - -/* The linenoiseState structure represents the state during line editing. - * We pass this state to functions implementing specific editing - * functionalities. */ -struct linenoiseState { - int ifd; /* Terminal stdin file descriptor. */ - int ofd; /* Terminal stdout file descriptor. */ - char *buf; /* Edited line buffer. */ - int buflen; /* Edited line buffer size. */ - std::string prompt; /* Prompt to display. */ - int pos; /* Current cursor position. */ - int oldcolpos; /* Previous refresh cursor column position. */ - int len; /* Current edited line length. */ - int cols; /* Number of columns in terminal. */ - int maxrows; /* Maximum num of rows used so far (multiline mode) */ - int history_index; /* The history index we are currently editing. */ -}; - -enum KEY_ACTION { - KEY_NULL = 0, /* NULL */ - CTRL_A = 1, /* Ctrl+a */ - CTRL_B = 2, /* Ctrl-b */ - CTRL_C = 3, /* Ctrl-c */ - CTRL_D = 4, /* Ctrl-d */ - CTRL_E = 5, /* Ctrl-e */ - CTRL_F = 6, /* Ctrl-f */ - CTRL_H = 8, /* Ctrl-h */ - TAB = 9, /* Tab */ - CTRL_K = 11, /* Ctrl+k */ - CTRL_L = 12, /* Ctrl+l */ - ENTER = 13, /* Enter */ - CTRL_N = 14, /* Ctrl-n */ - CTRL_P = 16, /* Ctrl-p */ - CTRL_T = 20, /* Ctrl-t */ - CTRL_U = 21, /* Ctrl+u */ - CTRL_W = 23, /* Ctrl+w */ - ESC = 27, /* Escape */ - BACKSPACE = 127 /* Backspace */ -}; - -void linenoiseAtExit(void); -bool AddHistory(const char *line); -void refreshLine(struct linenoiseState *l); - -/* ============================ UTF8 utilities ============================== */ - -static unsigned long unicodeWideCharTable[][2] = { - { 0x1100, 0x115F }, { 0x2329, 0x232A }, { 0x2E80, 0x2E99, }, { 0x2E9B, 0x2EF3, }, - { 0x2F00, 0x2FD5, }, { 0x2FF0, 0x2FFB, }, { 0x3000, 0x303E, }, { 0x3041, 0x3096, }, - { 0x3099, 0x30FF, }, { 0x3105, 0x312D, }, { 0x3131, 0x318E, }, { 0x3190, 0x31BA, }, - { 0x31C0, 0x31E3, }, { 0x31F0, 0x321E, }, { 0x3220, 0x3247, }, { 0x3250, 0x4DBF, }, - { 0x4E00, 0xA48C, }, { 0xA490, 0xA4C6, }, { 0xA960, 0xA97C, }, { 0xAC00, 0xD7A3, }, - { 0xF900, 0xFAFF, }, { 0xFE10, 0xFE19, }, { 0xFE30, 0xFE52, }, { 0xFE54, 0xFE66, }, - { 0xFE68, 0xFE6B, }, { 0xFF01, 0xFFE6, }, - { 0x1B000, 0x1B001, }, { 0x1F200, 0x1F202, }, { 0x1F210, 0x1F23A, }, - { 0x1F240, 0x1F248, }, { 0x1F250, 0x1F251, }, { 0x20000, 0x3FFFD, }, -}; - -static int unicodeWideCharTableSize = sizeof(unicodeWideCharTable) / sizeof(unicodeWideCharTable[0]); - -static int unicodeIsWideChar(unsigned long cp) -{ - int i; - for (i = 0; i < unicodeWideCharTableSize; i++) { - if (unicodeWideCharTable[i][0] <= cp && cp <= unicodeWideCharTable[i][1]) { - return 1; - } - } - return 0; -} - -static unsigned long unicodeCombiningCharTable[] = { - 0x0300,0x0301,0x0302,0x0303,0x0304,0x0305,0x0306,0x0307, - 0x0308,0x0309,0x030A,0x030B,0x030C,0x030D,0x030E,0x030F, - 0x0310,0x0311,0x0312,0x0313,0x0314,0x0315,0x0316,0x0317, - 0x0318,0x0319,0x031A,0x031B,0x031C,0x031D,0x031E,0x031F, - 0x0320,0x0321,0x0322,0x0323,0x0324,0x0325,0x0326,0x0327, - 0x0328,0x0329,0x032A,0x032B,0x032C,0x032D,0x032E,0x032F, - 0x0330,0x0331,0x0332,0x0333,0x0334,0x0335,0x0336,0x0337, - 0x0338,0x0339,0x033A,0x033B,0x033C,0x033D,0x033E,0x033F, - 0x0340,0x0341,0x0342,0x0343,0x0344,0x0345,0x0346,0x0347, - 0x0348,0x0349,0x034A,0x034B,0x034C,0x034D,0x034E,0x034F, - 0x0350,0x0351,0x0352,0x0353,0x0354,0x0355,0x0356,0x0357, - 0x0358,0x0359,0x035A,0x035B,0x035C,0x035D,0x035E,0x035F, - 0x0360,0x0361,0x0362,0x0363,0x0364,0x0365,0x0366,0x0367, - 0x0368,0x0369,0x036A,0x036B,0x036C,0x036D,0x036E,0x036F, - 0x0483,0x0484,0x0485,0x0486,0x0487,0x0591,0x0592,0x0593, - 0x0594,0x0595,0x0596,0x0597,0x0598,0x0599,0x059A,0x059B, - 0x059C,0x059D,0x059E,0x059F,0x05A0,0x05A1,0x05A2,0x05A3, - 0x05A4,0x05A5,0x05A6,0x05A7,0x05A8,0x05A9,0x05AA,0x05AB, - 0x05AC,0x05AD,0x05AE,0x05AF,0x05B0,0x05B1,0x05B2,0x05B3, - 0x05B4,0x05B5,0x05B6,0x05B7,0x05B8,0x05B9,0x05BA,0x05BB, - 0x05BC,0x05BD,0x05BF,0x05C1,0x05C2,0x05C4,0x05C5,0x05C7, - 0x0610,0x0611,0x0612,0x0613,0x0614,0x0615,0x0616,0x0617, - 0x0618,0x0619,0x061A,0x064B,0x064C,0x064D,0x064E,0x064F, - 0x0650,0x0651,0x0652,0x0653,0x0654,0x0655,0x0656,0x0657, - 0x0658,0x0659,0x065A,0x065B,0x065C,0x065D,0x065E,0x065F, - 0x0670,0x06D6,0x06D7,0x06D8,0x06D9,0x06DA,0x06DB,0x06DC, - 0x06DF,0x06E0,0x06E1,0x06E2,0x06E3,0x06E4,0x06E7,0x06E8, - 0x06EA,0x06EB,0x06EC,0x06ED,0x0711,0x0730,0x0731,0x0732, - 0x0733,0x0734,0x0735,0x0736,0x0737,0x0738,0x0739,0x073A, - 0x073B,0x073C,0x073D,0x073E,0x073F,0x0740,0x0741,0x0742, - 0x0743,0x0744,0x0745,0x0746,0x0747,0x0748,0x0749,0x074A, - 0x07A6,0x07A7,0x07A8,0x07A9,0x07AA,0x07AB,0x07AC,0x07AD, - 0x07AE,0x07AF,0x07B0,0x07EB,0x07EC,0x07ED,0x07EE,0x07EF, - 0x07F0,0x07F1,0x07F2,0x07F3,0x0816,0x0817,0x0818,0x0819, - 0x081B,0x081C,0x081D,0x081E,0x081F,0x0820,0x0821,0x0822, - 0x0823,0x0825,0x0826,0x0827,0x0829,0x082A,0x082B,0x082C, - 0x082D,0x0859,0x085A,0x085B,0x08E3,0x08E4,0x08E5,0x08E6, - 0x08E7,0x08E8,0x08E9,0x08EA,0x08EB,0x08EC,0x08ED,0x08EE, - 0x08EF,0x08F0,0x08F1,0x08F2,0x08F3,0x08F4,0x08F5,0x08F6, - 0x08F7,0x08F8,0x08F9,0x08FA,0x08FB,0x08FC,0x08FD,0x08FE, - 0x08FF,0x0900,0x0901,0x0902,0x093A,0x093C,0x0941,0x0942, - 0x0943,0x0944,0x0945,0x0946,0x0947,0x0948,0x094D,0x0951, - 0x0952,0x0953,0x0954,0x0955,0x0956,0x0957,0x0962,0x0963, - 0x0981,0x09BC,0x09C1,0x09C2,0x09C3,0x09C4,0x09CD,0x09E2, - 0x09E3,0x0A01,0x0A02,0x0A3C,0x0A41,0x0A42,0x0A47,0x0A48, - 0x0A4B,0x0A4C,0x0A4D,0x0A51,0x0A70,0x0A71,0x0A75,0x0A81, - 0x0A82,0x0ABC,0x0AC1,0x0AC2,0x0AC3,0x0AC4,0x0AC5,0x0AC7, - 0x0AC8,0x0ACD,0x0AE2,0x0AE3,0x0B01,0x0B3C,0x0B3F,0x0B41, - 0x0B42,0x0B43,0x0B44,0x0B4D,0x0B56,0x0B62,0x0B63,0x0B82, - 0x0BC0,0x0BCD,0x0C00,0x0C3E,0x0C3F,0x0C40,0x0C46,0x0C47, - 0x0C48,0x0C4A,0x0C4B,0x0C4C,0x0C4D,0x0C55,0x0C56,0x0C62, - 0x0C63,0x0C81,0x0CBC,0x0CBF,0x0CC6,0x0CCC,0x0CCD,0x0CE2, - 0x0CE3,0x0D01,0x0D41,0x0D42,0x0D43,0x0D44,0x0D4D,0x0D62, - 0x0D63,0x0DCA,0x0DD2,0x0DD3,0x0DD4,0x0DD6,0x0E31,0x0E34, - 0x0E35,0x0E36,0x0E37,0x0E38,0x0E39,0x0E3A,0x0E47,0x0E48, - 0x0E49,0x0E4A,0x0E4B,0x0E4C,0x0E4D,0x0E4E,0x0EB1,0x0EB4, - 0x0EB5,0x0EB6,0x0EB7,0x0EB8,0x0EB9,0x0EBB,0x0EBC,0x0EC8, - 0x0EC9,0x0ECA,0x0ECB,0x0ECC,0x0ECD,0x0F18,0x0F19,0x0F35, - 0x0F37,0x0F39,0x0F71,0x0F72,0x0F73,0x0F74,0x0F75,0x0F76, - 0x0F77,0x0F78,0x0F79,0x0F7A,0x0F7B,0x0F7C,0x0F7D,0x0F7E, - 0x0F80,0x0F81,0x0F82,0x0F83,0x0F84,0x0F86,0x0F87,0x0F8D, - 0x0F8E,0x0F8F,0x0F90,0x0F91,0x0F92,0x0F93,0x0F94,0x0F95, - 0x0F96,0x0F97,0x0F99,0x0F9A,0x0F9B,0x0F9C,0x0F9D,0x0F9E, - 0x0F9F,0x0FA0,0x0FA1,0x0FA2,0x0FA3,0x0FA4,0x0FA5,0x0FA6, - 0x0FA7,0x0FA8,0x0FA9,0x0FAA,0x0FAB,0x0FAC,0x0FAD,0x0FAE, - 0x0FAF,0x0FB0,0x0FB1,0x0FB2,0x0FB3,0x0FB4,0x0FB5,0x0FB6, - 0x0FB7,0x0FB8,0x0FB9,0x0FBA,0x0FBB,0x0FBC,0x0FC6,0x102D, - 0x102E,0x102F,0x1030,0x1032,0x1033,0x1034,0x1035,0x1036, - 0x1037,0x1039,0x103A,0x103D,0x103E,0x1058,0x1059,0x105E, - 0x105F,0x1060,0x1071,0x1072,0x1073,0x1074,0x1082,0x1085, - 0x1086,0x108D,0x109D,0x135D,0x135E,0x135F,0x1712,0x1713, - 0x1714,0x1732,0x1733,0x1734,0x1752,0x1753,0x1772,0x1773, - 0x17B4,0x17B5,0x17B7,0x17B8,0x17B9,0x17BA,0x17BB,0x17BC, - 0x17BD,0x17C6,0x17C9,0x17CA,0x17CB,0x17CC,0x17CD,0x17CE, - 0x17CF,0x17D0,0x17D1,0x17D2,0x17D3,0x17DD,0x180B,0x180C, - 0x180D,0x18A9,0x1920,0x1921,0x1922,0x1927,0x1928,0x1932, - 0x1939,0x193A,0x193B,0x1A17,0x1A18,0x1A1B,0x1A56,0x1A58, - 0x1A59,0x1A5A,0x1A5B,0x1A5C,0x1A5D,0x1A5E,0x1A60,0x1A62, - 0x1A65,0x1A66,0x1A67,0x1A68,0x1A69,0x1A6A,0x1A6B,0x1A6C, - 0x1A73,0x1A74,0x1A75,0x1A76,0x1A77,0x1A78,0x1A79,0x1A7A, - 0x1A7B,0x1A7C,0x1A7F,0x1AB0,0x1AB1,0x1AB2,0x1AB3,0x1AB4, - 0x1AB5,0x1AB6,0x1AB7,0x1AB8,0x1AB9,0x1ABA,0x1ABB,0x1ABC, - 0x1ABD,0x1B00,0x1B01,0x1B02,0x1B03,0x1B34,0x1B36,0x1B37, - 0x1B38,0x1B39,0x1B3A,0x1B3C,0x1B42,0x1B6B,0x1B6C,0x1B6D, - 0x1B6E,0x1B6F,0x1B70,0x1B71,0x1B72,0x1B73,0x1B80,0x1B81, - 0x1BA2,0x1BA3,0x1BA4,0x1BA5,0x1BA8,0x1BA9,0x1BAB,0x1BAC, - 0x1BAD,0x1BE6,0x1BE8,0x1BE9,0x1BED,0x1BEF,0x1BF0,0x1BF1, - 0x1C2C,0x1C2D,0x1C2E,0x1C2F,0x1C30,0x1C31,0x1C32,0x1C33, - 0x1C36,0x1C37,0x1CD0,0x1CD1,0x1CD2,0x1CD4,0x1CD5,0x1CD6, - 0x1CD7,0x1CD8,0x1CD9,0x1CDA,0x1CDB,0x1CDC,0x1CDD,0x1CDE, - 0x1CDF,0x1CE0,0x1CE2,0x1CE3,0x1CE4,0x1CE5,0x1CE6,0x1CE7, - 0x1CE8,0x1CED,0x1CF4,0x1CF8,0x1CF9,0x1DC0,0x1DC1,0x1DC2, - 0x1DC3,0x1DC4,0x1DC5,0x1DC6,0x1DC7,0x1DC8,0x1DC9,0x1DCA, - 0x1DCB,0x1DCC,0x1DCD,0x1DCE,0x1DCF,0x1DD0,0x1DD1,0x1DD2, - 0x1DD3,0x1DD4,0x1DD5,0x1DD6,0x1DD7,0x1DD8,0x1DD9,0x1DDA, - 0x1DDB,0x1DDC,0x1DDD,0x1DDE,0x1DDF,0x1DE0,0x1DE1,0x1DE2, - 0x1DE3,0x1DE4,0x1DE5,0x1DE6,0x1DE7,0x1DE8,0x1DE9,0x1DEA, - 0x1DEB,0x1DEC,0x1DED,0x1DEE,0x1DEF,0x1DF0,0x1DF1,0x1DF2, - 0x1DF3,0x1DF4,0x1DF5,0x1DFC,0x1DFD,0x1DFE,0x1DFF,0x20D0, - 0x20D1,0x20D2,0x20D3,0x20D4,0x20D5,0x20D6,0x20D7,0x20D8, - 0x20D9,0x20DA,0x20DB,0x20DC,0x20E1,0x20E5,0x20E6,0x20E7, - 0x20E8,0x20E9,0x20EA,0x20EB,0x20EC,0x20ED,0x20EE,0x20EF, - 0x20F0,0x2CEF,0x2CF0,0x2CF1,0x2D7F,0x2DE0,0x2DE1,0x2DE2, - 0x2DE3,0x2DE4,0x2DE5,0x2DE6,0x2DE7,0x2DE8,0x2DE9,0x2DEA, - 0x2DEB,0x2DEC,0x2DED,0x2DEE,0x2DEF,0x2DF0,0x2DF1,0x2DF2, - 0x2DF3,0x2DF4,0x2DF5,0x2DF6,0x2DF7,0x2DF8,0x2DF9,0x2DFA, - 0x2DFB,0x2DFC,0x2DFD,0x2DFE,0x2DFF,0x302A,0x302B,0x302C, - 0x302D,0x3099,0x309A,0xA66F,0xA674,0xA675,0xA676,0xA677, - 0xA678,0xA679,0xA67A,0xA67B,0xA67C,0xA67D,0xA69E,0xA69F, - 0xA6F0,0xA6F1,0xA802,0xA806,0xA80B,0xA825,0xA826,0xA8C4, - 0xA8E0,0xA8E1,0xA8E2,0xA8E3,0xA8E4,0xA8E5,0xA8E6,0xA8E7, - 0xA8E8,0xA8E9,0xA8EA,0xA8EB,0xA8EC,0xA8ED,0xA8EE,0xA8EF, - 0xA8F0,0xA8F1,0xA926,0xA927,0xA928,0xA929,0xA92A,0xA92B, - 0xA92C,0xA92D,0xA947,0xA948,0xA949,0xA94A,0xA94B,0xA94C, - 0xA94D,0xA94E,0xA94F,0xA950,0xA951,0xA980,0xA981,0xA982, - 0xA9B3,0xA9B6,0xA9B7,0xA9B8,0xA9B9,0xA9BC,0xA9E5,0xAA29, - 0xAA2A,0xAA2B,0xAA2C,0xAA2D,0xAA2E,0xAA31,0xAA32,0xAA35, - 0xAA36,0xAA43,0xAA4C,0xAA7C,0xAAB0,0xAAB2,0xAAB3,0xAAB4, - 0xAAB7,0xAAB8,0xAABE,0xAABF,0xAAC1,0xAAEC,0xAAED,0xAAF6, - 0xABE5,0xABE8,0xABED,0xFB1E,0xFE00,0xFE01,0xFE02,0xFE03, - 0xFE04,0xFE05,0xFE06,0xFE07,0xFE08,0xFE09,0xFE0A,0xFE0B, - 0xFE0C,0xFE0D,0xFE0E,0xFE0F,0xFE20,0xFE21,0xFE22,0xFE23, - 0xFE24,0xFE25,0xFE26,0xFE27,0xFE28,0xFE29,0xFE2A,0xFE2B, - 0xFE2C,0xFE2D,0xFE2E,0xFE2F, - 0x101FD,0x102E0,0x10376,0x10377,0x10378,0x10379,0x1037A,0x10A01, - 0x10A02,0x10A03,0x10A05,0x10A06,0x10A0C,0x10A0D,0x10A0E,0x10A0F, - 0x10A38,0x10A39,0x10A3A,0x10A3F,0x10AE5,0x10AE6,0x11001,0x11038, - 0x11039,0x1103A,0x1103B,0x1103C,0x1103D,0x1103E,0x1103F,0x11040, - 0x11041,0x11042,0x11043,0x11044,0x11045,0x11046,0x1107F,0x11080, - 0x11081,0x110B3,0x110B4,0x110B5,0x110B6,0x110B9,0x110BA,0x11100, - 0x11101,0x11102,0x11127,0x11128,0x11129,0x1112A,0x1112B,0x1112D, - 0x1112E,0x1112F,0x11130,0x11131,0x11132,0x11133,0x11134,0x11173, - 0x11180,0x11181,0x111B6,0x111B7,0x111B8,0x111B9,0x111BA,0x111BB, - 0x111BC,0x111BD,0x111BE,0x111CA,0x111CB,0x111CC,0x1122F,0x11230, - 0x11231,0x11234,0x11236,0x11237,0x112DF,0x112E3,0x112E4,0x112E5, - 0x112E6,0x112E7,0x112E8,0x112E9,0x112EA,0x11300,0x11301,0x1133C, - 0x11340,0x11366,0x11367,0x11368,0x11369,0x1136A,0x1136B,0x1136C, - 0x11370,0x11371,0x11372,0x11373,0x11374,0x114B3,0x114B4,0x114B5, - 0x114B6,0x114B7,0x114B8,0x114BA,0x114BF,0x114C0,0x114C2,0x114C3, - 0x115B2,0x115B3,0x115B4,0x115B5,0x115BC,0x115BD,0x115BF,0x115C0, - 0x115DC,0x115DD,0x11633,0x11634,0x11635,0x11636,0x11637,0x11638, - 0x11639,0x1163A,0x1163D,0x1163F,0x11640,0x116AB,0x116AD,0x116B0, - 0x116B1,0x116B2,0x116B3,0x116B4,0x116B5,0x116B7,0x1171D,0x1171E, - 0x1171F,0x11722,0x11723,0x11724,0x11725,0x11727,0x11728,0x11729, - 0x1172A,0x1172B,0x16AF0,0x16AF1,0x16AF2,0x16AF3,0x16AF4,0x16B30, - 0x16B31,0x16B32,0x16B33,0x16B34,0x16B35,0x16B36,0x16F8F,0x16F90, - 0x16F91,0x16F92,0x1BC9D,0x1BC9E,0x1D167,0x1D168,0x1D169,0x1D17B, - 0x1D17C,0x1D17D,0x1D17E,0x1D17F,0x1D180,0x1D181,0x1D182,0x1D185, - 0x1D186,0x1D187,0x1D188,0x1D189,0x1D18A,0x1D18B,0x1D1AA,0x1D1AB, - 0x1D1AC,0x1D1AD,0x1D242,0x1D243,0x1D244,0x1DA00,0x1DA01,0x1DA02, - 0x1DA03,0x1DA04,0x1DA05,0x1DA06,0x1DA07,0x1DA08,0x1DA09,0x1DA0A, - 0x1DA0B,0x1DA0C,0x1DA0D,0x1DA0E,0x1DA0F,0x1DA10,0x1DA11,0x1DA12, - 0x1DA13,0x1DA14,0x1DA15,0x1DA16,0x1DA17,0x1DA18,0x1DA19,0x1DA1A, - 0x1DA1B,0x1DA1C,0x1DA1D,0x1DA1E,0x1DA1F,0x1DA20,0x1DA21,0x1DA22, - 0x1DA23,0x1DA24,0x1DA25,0x1DA26,0x1DA27,0x1DA28,0x1DA29,0x1DA2A, - 0x1DA2B,0x1DA2C,0x1DA2D,0x1DA2E,0x1DA2F,0x1DA30,0x1DA31,0x1DA32, - 0x1DA33,0x1DA34,0x1DA35,0x1DA36,0x1DA3B,0x1DA3C,0x1DA3D,0x1DA3E, - 0x1DA3F,0x1DA40,0x1DA41,0x1DA42,0x1DA43,0x1DA44,0x1DA45,0x1DA46, - 0x1DA47,0x1DA48,0x1DA49,0x1DA4A,0x1DA4B,0x1DA4C,0x1DA4D,0x1DA4E, - 0x1DA4F,0x1DA50,0x1DA51,0x1DA52,0x1DA53,0x1DA54,0x1DA55,0x1DA56, - 0x1DA57,0x1DA58,0x1DA59,0x1DA5A,0x1DA5B,0x1DA5C,0x1DA5D,0x1DA5E, - 0x1DA5F,0x1DA60,0x1DA61,0x1DA62,0x1DA63,0x1DA64,0x1DA65,0x1DA66, - 0x1DA67,0x1DA68,0x1DA69,0x1DA6A,0x1DA6B,0x1DA6C,0x1DA75,0x1DA84, - 0x1DA9B,0x1DA9C,0x1DA9D,0x1DA9E,0x1DA9F,0x1DAA1,0x1DAA2,0x1DAA3, - 0x1DAA4,0x1DAA5,0x1DAA6,0x1DAA7,0x1DAA8,0x1DAA9,0x1DAAA,0x1DAAB, - 0x1DAAC,0x1DAAD,0x1DAAE,0x1DAAF,0x1E8D0,0x1E8D1,0x1E8D2,0x1E8D3, - 0x1E8D4,0x1E8D5,0x1E8D6,0xE0100,0xE0101,0xE0102,0xE0103,0xE0104, - 0xE0105,0xE0106,0xE0107,0xE0108,0xE0109,0xE010A,0xE010B,0xE010C, - 0xE010D,0xE010E,0xE010F,0xE0110,0xE0111,0xE0112,0xE0113,0xE0114, - 0xE0115,0xE0116,0xE0117,0xE0118,0xE0119,0xE011A,0xE011B,0xE011C, - 0xE011D,0xE011E,0xE011F,0xE0120,0xE0121,0xE0122,0xE0123,0xE0124, - 0xE0125,0xE0126,0xE0127,0xE0128,0xE0129,0xE012A,0xE012B,0xE012C, - 0xE012D,0xE012E,0xE012F,0xE0130,0xE0131,0xE0132,0xE0133,0xE0134, - 0xE0135,0xE0136,0xE0137,0xE0138,0xE0139,0xE013A,0xE013B,0xE013C, - 0xE013D,0xE013E,0xE013F,0xE0140,0xE0141,0xE0142,0xE0143,0xE0144, - 0xE0145,0xE0146,0xE0147,0xE0148,0xE0149,0xE014A,0xE014B,0xE014C, - 0xE014D,0xE014E,0xE014F,0xE0150,0xE0151,0xE0152,0xE0153,0xE0154, - 0xE0155,0xE0156,0xE0157,0xE0158,0xE0159,0xE015A,0xE015B,0xE015C, - 0xE015D,0xE015E,0xE015F,0xE0160,0xE0161,0xE0162,0xE0163,0xE0164, - 0xE0165,0xE0166,0xE0167,0xE0168,0xE0169,0xE016A,0xE016B,0xE016C, - 0xE016D,0xE016E,0xE016F,0xE0170,0xE0171,0xE0172,0xE0173,0xE0174, - 0xE0175,0xE0176,0xE0177,0xE0178,0xE0179,0xE017A,0xE017B,0xE017C, - 0xE017D,0xE017E,0xE017F,0xE0180,0xE0181,0xE0182,0xE0183,0xE0184, - 0xE0185,0xE0186,0xE0187,0xE0188,0xE0189,0xE018A,0xE018B,0xE018C, - 0xE018D,0xE018E,0xE018F,0xE0190,0xE0191,0xE0192,0xE0193,0xE0194, - 0xE0195,0xE0196,0xE0197,0xE0198,0xE0199,0xE019A,0xE019B,0xE019C, - 0xE019D,0xE019E,0xE019F,0xE01A0,0xE01A1,0xE01A2,0xE01A3,0xE01A4, - 0xE01A5,0xE01A6,0xE01A7,0xE01A8,0xE01A9,0xE01AA,0xE01AB,0xE01AC, - 0xE01AD,0xE01AE,0xE01AF,0xE01B0,0xE01B1,0xE01B2,0xE01B3,0xE01B4, - 0xE01B5,0xE01B6,0xE01B7,0xE01B8,0xE01B9,0xE01BA,0xE01BB,0xE01BC, - 0xE01BD,0xE01BE,0xE01BF,0xE01C0,0xE01C1,0xE01C2,0xE01C3,0xE01C4, - 0xE01C5,0xE01C6,0xE01C7,0xE01C8,0xE01C9,0xE01CA,0xE01CB,0xE01CC, - 0xE01CD,0xE01CE,0xE01CF,0xE01D0,0xE01D1,0xE01D2,0xE01D3,0xE01D4, - 0xE01D5,0xE01D6,0xE01D7,0xE01D8,0xE01D9,0xE01DA,0xE01DB,0xE01DC, - 0xE01DD,0xE01DE,0xE01DF,0xE01E0,0xE01E1,0xE01E2,0xE01E3,0xE01E4, - 0xE01E5,0xE01E6,0xE01E7,0xE01E8,0xE01E9,0xE01EA,0xE01EB,0xE01EC, - 0xE01ED,0xE01EE,0xE01EF, -}; - -static int unicodeCombiningCharTableSize = sizeof(unicodeCombiningCharTable) / sizeof(unicodeCombiningCharTable[0]); - -inline int unicodeIsCombiningChar(unsigned long cp) -{ - int i; - for (i = 0; i < unicodeCombiningCharTableSize; i++) { - if (unicodeCombiningCharTable[i] == cp) { - return 1; - } - } - return 0; -} - -/* Get length of previous UTF8 character - */ -inline int unicodePrevUTF8CharLen(char* buf, int pos) -{ - int end = pos--; - while (pos >= 0 && ((unsigned char)buf[pos] & 0xC0) == 0x80) { - pos--; - } - return end - pos; -} - -/* Get length of previous UTF8 character - */ -inline int unicodeUTF8CharLen(char* buf, int buf_len, int pos) -{ - if (pos == buf_len) { return 0; } - unsigned char ch = buf[pos]; - if (ch < 0x80) { return 1; } - else if (ch < 0xE0) { return 2; } - else if (ch < 0xF0) { return 3; } - else { return 4; } -} - -/* Convert UTF8 to Unicode code point - */ -inline int unicodeUTF8CharToCodePoint( - const char* buf, - int len, - int* cp) -{ - if (len) { - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - *cp = byte; - return 1; - } else if ((byte & 0xE0) == 0xC0) { - if (len >= 2) { - *cp = (((unsigned long)(buf[0] & 0x1F)) << 6) | - ((unsigned long)(buf[1] & 0x3F)); - return 2; - } - } else if ((byte & 0xF0) == 0xE0) { - if (len >= 3) { - *cp = (((unsigned long)(buf[0] & 0x0F)) << 12) | - (((unsigned long)(buf[1] & 0x3F)) << 6) | - ((unsigned long)(buf[2] & 0x3F)); - return 3; - } - } else if ((byte & 0xF8) == 0xF0) { - if (len >= 4) { - *cp = (((unsigned long)(buf[0] & 0x07)) << 18) | - (((unsigned long)(buf[1] & 0x3F)) << 12) | - (((unsigned long)(buf[2] & 0x3F)) << 6) | - ((unsigned long)(buf[3] & 0x3F)); - return 4; - } - } - } - return 0; -} - -/* Get length of grapheme - */ -inline int unicodeGraphemeLen(char* buf, int buf_len, int pos) -{ - if (pos == buf_len) { - return 0; - } - int beg = pos; - pos += unicodeUTF8CharLen(buf, buf_len, pos); - while (pos < buf_len) { - int len = unicodeUTF8CharLen(buf, buf_len, pos); - int cp = 0; - unicodeUTF8CharToCodePoint(buf + pos, len, &cp); - if (!unicodeIsCombiningChar(cp)) { - return pos - beg; - } - pos += len; - } - return pos - beg; -} - -/* Get length of previous grapheme - */ -inline int unicodePrevGraphemeLen(char* buf, int pos) -{ - if (pos == 0) { - return 0; - } - int end = pos; - while (pos > 0) { - int len = unicodePrevUTF8CharLen(buf, pos); - pos -= len; - int cp = 0; - unicodeUTF8CharToCodePoint(buf + pos, len, &cp); - if (!unicodeIsCombiningChar(cp)) { - return end - pos; - } - } - return 0; -} - -inline int isAnsiEscape(const char* buf, int buf_len, int* len) -{ - if (buf_len > 2 && !memcmp("\033[", buf, 2)) { - int off = 2; - while (off < buf_len) { - switch (buf[off++]) { - case 'A': case 'B': case 'C': case 'D': - case 'E': case 'F': case 'G': case 'H': - case 'J': case 'K': case 'S': case 'T': - case 'f': case 'm': - *len = off; - return 1; - } - } - } - return 0; -} - -/* Get column position for the single line mode. - */ -inline int unicodeColumnPos(const char* buf, int buf_len) -{ - int ret = 0; - - int off = 0; - while (off < buf_len) { - int len; - if (isAnsiEscape(buf + off, buf_len - off, &len)) { - off += len; - continue; - } - - int cp = 0; - len = unicodeUTF8CharToCodePoint(buf + off, buf_len - off, &cp); - - if (!unicodeIsCombiningChar(cp)) { - ret += unicodeIsWideChar(cp) ? 2 : 1; - } - - off += len; - } - - return ret; -} - -/* Get column position for the multi line mode. - */ -inline int unicodeColumnPosForMultiLine(char* buf, int buf_len, int pos, int cols, int ini_pos) -{ - int ret = 0; - int colwid = ini_pos; - - int off = 0; - while (off < buf_len) { - int cp = 0; - int len = unicodeUTF8CharToCodePoint(buf + off, buf_len - off, &cp); - - int wid = 0; - if (!unicodeIsCombiningChar(cp)) { - wid = unicodeIsWideChar(cp) ? 2 : 1; - } - - int dif = (int)(colwid + wid) - (int)cols; - if (dif > 0) { - ret += dif; - colwid = wid; - } else if (dif == 0) { - colwid = 0; - } else { - colwid += wid; - } - - if (off >= pos) { - break; - } - - off += len; - ret += wid; - } - - return ret; -} - -/* Read UTF8 character from file. - */ -inline int unicodeReadUTF8Char(int fd, char* buf, int* cp) -{ - int nread = read(fd,&buf[0],1); - - if (nread <= 0) { return nread; } - - unsigned char byte = buf[0]; - - if ((byte & 0x80) == 0) { - ; - } else if ((byte & 0xE0) == 0xC0) { - nread = read(fd,&buf[1],1); - if (nread <= 0) { return nread; } - } else if ((byte & 0xF0) == 0xE0) { - nread = read(fd,&buf[1],2); - if (nread <= 0) { return nread; } - } else if ((byte & 0xF8) == 0xF0) { - nread = read(fd,&buf[1],3); - if (nread <= 0) { return nread; } - } else { - return -1; - } - - return unicodeUTF8CharToCodePoint(buf, 4, cp); -} - -/* ======================= Low level terminal handling ====================== */ - -/* Set if to use or not the multi line mode. */ -inline void SetMultiLine(bool ml) { - mlmode = ml; -} - -/* Return true if the terminal name is in the list of terminals we know are - * not able to understand basic escape sequences. */ -inline bool isUnsupportedTerm(void) { -#ifndef _WIN32 - char *term = getenv("TERM"); - int j; - - if (term == NULL) return false; - for (j = 0; unsupported_term[j]; j++) - if (!strcasecmp(term,unsupported_term[j])) return true; -#endif - return false; -} - -/* Raw mode: 1960 magic shit. */ -inline bool enableRawMode(int fd) { -#ifndef _WIN32 - struct termios raw; - - if (!isatty(STDIN_FILENO)) goto fatal; - if (!atexit_registered) { - atexit(linenoiseAtExit); - atexit_registered = true; - } - if (tcgetattr(fd,&orig_termios) == -1) goto fatal; - - raw = orig_termios; /* modify the original mode */ - /* input modes: no break, no CR to NL, no parity check, no strip char, - * no start/stop output control. */ - raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); - /* output modes - disable post processing */ - // NOTE: Multithreaded issue #20 (https://github.com/yhirose/cpp-linenoise/issues/20) - // raw.c_oflag &= ~(OPOST); - /* control modes - set 8 bit chars */ - raw.c_cflag |= (CS8); - /* local modes - echoing off, canonical off, no extended functions, - * no signal chars (^Z,^C) */ - raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); - /* control chars - set return condition: min number of bytes and timer. - * We want read to return every single byte, without timeout. */ - raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ - - /* put terminal in raw mode after flushing */ - if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; - rawmode = true; -#else - if (!atexit_registered) { - /* Cleanup them at exit */ - atexit(linenoiseAtExit); - atexit_registered = true; - - /* Init windows console handles only once */ - hOut = GetStdHandle(STD_OUTPUT_HANDLE); - if (hOut==INVALID_HANDLE_VALUE) goto fatal; - } - - DWORD consolemodeOut; - if (!GetConsoleMode(hOut, &consolemodeOut)) { - CloseHandle(hOut); - errno = ENOTTY; - return false; - }; - - hIn = GetStdHandle(STD_INPUT_HANDLE); - if (hIn == INVALID_HANDLE_VALUE) { - CloseHandle(hOut); - errno = ENOTTY; - return false; - } - - GetConsoleMode(hIn, &consolemodeIn); - /* Enable raw mode */ - SetConsoleMode(hIn, consolemodeIn & ~ENABLE_PROCESSED_INPUT); - - rawmode = true; -#endif - return true; - -fatal: - errno = ENOTTY; - return false; -} - -inline void disableRawMode(int fd) { -#ifdef _WIN32 - if (consolemodeIn) { - SetConsoleMode(hIn, consolemodeIn); - consolemodeIn = 0; - } - rawmode = false; -#else - /* Don't even check the return value as it's too late. */ - if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) - rawmode = false; -#endif -} - -/* Use the ESC [6n escape sequence to query the horizontal cursor position - * and return it. On error -1 is returned, on success the position of the - * cursor. */ -inline int getCursorPosition(int ifd, int ofd) { - char buf[32]; - int cols, rows; - unsigned int i = 0; - - /* Report cursor location */ - if (write(ofd, "\x1b[6n", 4) != 4) return -1; - - /* Read the response: ESC [ rows ; cols R */ - while (i < sizeof(buf)-1) { - if (read(ifd,buf+i,1) != 1) break; - if (buf[i] == 'R') break; - i++; - } - buf[i] = '\0'; - - /* Parse it. */ - if (buf[0] != ESC || buf[1] != '[') return -1; - if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; - return cols; -} - -/* Try to get the number of columns in the current terminal, or assume 80 - * if it fails. */ -inline int getColumns(int ifd, int ofd) { -#ifdef _WIN32 - CONSOLE_SCREEN_BUFFER_INFO b; - - if (!GetConsoleScreenBufferInfo(hOut, &b)) return 80; - return b.srWindow.Right - b.srWindow.Left; -#else - struct winsize ws; - - if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { - /* ioctl() failed. Try to query the terminal itself. */ - int start, cols; - - /* Get the initial position so we can restore it later. */ - start = getCursorPosition(ifd,ofd); - if (start == -1) goto failed; - - /* Go to right margin and get position. */ - if (write(ofd,"\x1b[999C",6) != 6) goto failed; - cols = getCursorPosition(ifd,ofd); - if (cols == -1) goto failed; - - /* Restore position. */ - if (cols > start) { - char seq[32]; - snprintf(seq,32,"\x1b[%dD",cols-start); - if (write(ofd,seq,strlen(seq)) == -1) { - /* Can't recover... */ - } - } - return cols; - } else { - return ws.ws_col; - } - -failed: - return 80; -#endif -} - -/* Clear the screen. Used to handle ctrl+l */ -inline void linenoiseClearScreen(void) { - if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { - /* nothing to do, just to avoid warning. */ - } -} - -/* Beep, used for completion when there is nothing to complete or when all - * the choices were already shown. */ -inline void linenoiseBeep(void) { - fprintf(stderr, "\x7"); - fflush(stderr); -} - -/* ============================== Completion ================================ */ - -/* This is an helper function for linenoiseEdit() and is called when the - * user types the key in order to complete the string currently in the - * input. - * - * The state of the editing is encapsulated into the pointed linenoiseState - * structure as described in the structure definition. */ -inline int completeLine(struct linenoiseState *ls, char *cbuf, int *c) { - std::vector lc; - int nread = 0, nwritten; - *c = 0; - - completionCallback(ls->buf,lc); - if (lc.empty()) { - linenoiseBeep(); - } else { - int stop = 0, i = 0; - - while(!stop) { - /* Show completion or original buffer */ - if (i < static_cast(lc.size())) { - struct linenoiseState saved = *ls; - - ls->len = ls->pos = static_cast(lc[i].size()); - ls->buf = &lc[i][0]; - refreshLine(ls); - ls->len = saved.len; - ls->pos = saved.pos; - ls->buf = saved.buf; - } else { - refreshLine(ls); - } - - //nread = read(ls->ifd,&c,1); -#ifdef _WIN32 - nread = win32read(c); - if (nread == 1) { - cbuf[0] = *c; - } -#else - nread = unicodeReadUTF8Char(ls->ifd,cbuf,c); -#endif - if (nread <= 0) { - *c = -1; - return nread; - } - - switch(*c) { - case 9: /* tab */ - i = (i+1) % (lc.size()+1); - if (i == static_cast(lc.size())) linenoiseBeep(); - break; - case 27: /* escape */ - /* Re-show original buffer */ - if (i < static_cast(lc.size())) refreshLine(ls); - stop = 1; - break; - default: - /* Update buffer and return */ - if (i < static_cast(lc.size())) { - nwritten = snprintf(ls->buf,ls->buflen,"%s",&lc[i][0]); - ls->len = ls->pos = nwritten; - } - stop = 1; - break; - } - } - } - - return nread; -} - -/* Register a callback function to be called for tab-completion. */ -inline void SetCompletionCallback(CompletionCallback fn) { - completionCallback = fn; -} - -/* =========================== Line editing ================================= */ - -/* Single line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. */ -inline void refreshSingleLine(struct linenoiseState *l) { - char seq[64]; - int pcolwid = unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length())); - int fd = l->ofd; - char *buf = l->buf; - int len = l->len; - int pos = l->pos; - std::string ab; - - while((pcolwid+unicodeColumnPos(buf, pos)) >= l->cols) { - int glen = unicodeGraphemeLen(buf, len, 0); - buf += glen; - len -= glen; - pos -= glen; - } - while (pcolwid+unicodeColumnPos(buf, len) > l->cols) { - len -= unicodePrevGraphemeLen(buf, len); - } - - /* Cursor to left edge */ - snprintf(seq,64,"\r"); - ab += seq; - /* Write the prompt and the current buffer content */ - ab += l->prompt; - ab.append(buf, len); - /* Erase to right */ - snprintf(seq,64,"\x1b[0K"); - ab += seq; - /* Move cursor to original position. */ - snprintf(seq,64,"\r\x1b[%dC", (int)(unicodeColumnPos(buf, pos)+pcolwid)); - ab += seq; - if (write(fd,ab.c_str(), static_cast(ab.length())) == -1) {} /* Can't recover from write error. */ -} - -/* Multi line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. */ -inline void refreshMultiLine(struct linenoiseState *l) { - char seq[64]; - int pcolwid = unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length())); - int colpos = unicodeColumnPosForMultiLine(l->buf, l->len, l->len, l->cols, pcolwid); - int colpos2; /* cursor column position. */ - int rows = (pcolwid+colpos+l->cols-1)/l->cols; /* rows used by current buf. */ - int rpos = (pcolwid+l->oldcolpos+l->cols)/l->cols; /* cursor relative row. */ - int rpos2; /* rpos after refresh. */ - int col; /* colum position, zero-based. */ - int old_rows = (int)l->maxrows; - int fd = l->ofd, j; - std::string ab; - - /* Update maxrows if needed. */ - if (rows > (int)l->maxrows) l->maxrows = rows; - - /* First step: clear all the lines used before. To do so start by - * going to the last row. */ - if (old_rows-rpos > 0) { - snprintf(seq,64,"\x1b[%dB", old_rows-rpos); - ab += seq; - } - - /* Now for every row clear it, go up. */ - for (j = 0; j < old_rows-1; j++) { - snprintf(seq,64,"\r\x1b[0K\x1b[1A"); - ab += seq; - } - - /* Clean the top line. */ - snprintf(seq,64,"\r\x1b[0K"); - ab += seq; - - /* Write the prompt and the current buffer content */ - ab += l->prompt; - ab.append(l->buf, l->len); - - /* Get text width to cursor position */ - colpos2 = unicodeColumnPosForMultiLine(l->buf, l->len, l->pos, l->cols, pcolwid); - - /* If we are at the very end of the screen with our prompt, we need to - * emit a newline and move the prompt to the first column. */ - if (l->pos && - l->pos == l->len && - (colpos2+pcolwid) % l->cols == 0) - { - ab += "\n"; - snprintf(seq,64,"\r"); - ab += seq; - rows++; - if (rows > (int)l->maxrows) l->maxrows = rows; - } - - /* Move cursor to right position. */ - rpos2 = (pcolwid+colpos2+l->cols)/l->cols; /* current cursor relative row. */ - - /* Go up till we reach the expected positon. */ - if (rows-rpos2 > 0) { - snprintf(seq,64,"\x1b[%dA", rows-rpos2); - ab += seq; - } - - /* Set column. */ - col = (pcolwid + colpos2) % l->cols; - if (col) - snprintf(seq,64,"\r\x1b[%dC", col); - else - snprintf(seq,64,"\r"); - ab += seq; - - l->oldcolpos = colpos2; - - if (write(fd,ab.c_str(), static_cast(ab.length())) == -1) {} /* Can't recover from write error. */ -} - -/* Calls the two low level functions refreshSingleLine() or - * refreshMultiLine() according to the selected mode. */ -inline void refreshLine(struct linenoiseState *l) { - if (mlmode) - refreshMultiLine(l); - else - refreshSingleLine(l); -} - -/* Insert the character 'c' at cursor current position. - * - * On error writing to the terminal -1 is returned, otherwise 0. */ -inline int linenoiseEditInsert(struct linenoiseState *l, const char* cbuf, int clen) { - if (l->len < l->buflen) { - if (l->len == l->pos) { - memcpy(&l->buf[l->pos],cbuf,clen); - l->pos+=clen; - l->len+=clen;; - l->buf[l->len] = '\0'; - if ((!mlmode && unicodeColumnPos(l->prompt.c_str(), static_cast(l->prompt.length()))+unicodeColumnPos(l->buf,l->len) < l->cols) /* || mlmode */) { - /* Avoid a full update of the line in the - * trivial case. */ - if (write(l->ofd,cbuf,clen) == -1) return -1; - } else { - refreshLine(l); - } - } else { - memmove(l->buf+l->pos+clen,l->buf+l->pos,l->len-l->pos); - memcpy(&l->buf[l->pos],cbuf,clen); - l->pos+=clen; - l->len+=clen; - l->buf[l->len] = '\0'; - refreshLine(l); - } - } - return 0; -} - -/* Move cursor on the left. */ -inline void linenoiseEditMoveLeft(struct linenoiseState *l) { - if (l->pos > 0) { - l->pos -= unicodePrevGraphemeLen(l->buf, l->pos); - refreshLine(l); - } -} - -/* Move cursor on the right. */ -inline void linenoiseEditMoveRight(struct linenoiseState *l) { - if (l->pos != l->len) { - l->pos += unicodeGraphemeLen(l->buf, l->len, l->pos); - refreshLine(l); - } -} - -/* Move cursor to the start of the line. */ -inline void linenoiseEditMoveHome(struct linenoiseState *l) { - if (l->pos != 0) { - l->pos = 0; - refreshLine(l); - } -} - -/* Move cursor to the end of the line. */ -inline void linenoiseEditMoveEnd(struct linenoiseState *l) { - if (l->pos != l->len) { - l->pos = l->len; - refreshLine(l); - } -} - -/* Substitute the currently edited line with the next or previous history - * entry as specified by 'dir'. */ -#define LINENOISE_HISTORY_NEXT 0 -#define LINENOISE_HISTORY_PREV 1 -inline void linenoiseEditHistoryNext(struct linenoiseState *l, int dir) { - if (history.size() > 1) { - /* Update the current history entry before to - * overwrite it with the next one. */ - history[history.size() - 1 - l->history_index] = l->buf; - /* Show the new entry */ - l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; - if (l->history_index < 0) { - l->history_index = 0; - return; - } else if (l->history_index >= (int)history.size()) { - l->history_index = static_cast(history.size())-1; - return; - } - memset(l->buf, 0, l->buflen); - strcpy(l->buf,history[history.size() - 1 - l->history_index].c_str()); - l->len = l->pos = static_cast(strlen(l->buf)); - refreshLine(l); - } -} - -/* Delete the character at the right of the cursor without altering the cursor - * position. Basically this is what happens with the "Delete" keyboard key. */ -inline void linenoiseEditDelete(struct linenoiseState *l) { - if (l->len > 0 && l->pos < l->len) { - int glen = unicodeGraphemeLen(l->buf,l->len,l->pos); - memmove(l->buf+l->pos,l->buf+l->pos+glen,l->len-l->pos-glen); - l->len-=glen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Backspace implementation. */ -inline void linenoiseEditBackspace(struct linenoiseState *l) { - if (l->pos > 0 && l->len > 0) { - int glen = unicodePrevGraphemeLen(l->buf,l->pos); - memmove(l->buf+l->pos-glen,l->buf+l->pos,l->len-l->pos); - l->pos-=glen; - l->len-=glen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Delete the previosu word, maintaining the cursor at the start of the - * current word. */ -inline void linenoiseEditDeletePrevWord(struct linenoiseState *l) { - int old_pos = l->pos; - int diff; - - while (l->pos > 0 && l->buf[l->pos-1] == ' ') - l->pos--; - while (l->pos > 0 && l->buf[l->pos-1] != ' ') - l->pos--; - diff = old_pos - l->pos; - memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); - l->len -= diff; - refreshLine(l); -} - -/* This function is the core of the line editing capability of linenoise. - * It expects 'fd' to be already in "raw mode" so that every key pressed - * will be returned ASAP to read(). - * - * The resulting string is put into 'buf' when the user type enter, or - * when ctrl+d is typed. - * - * The function returns the length of the current buffer. */ -inline int linenoiseEdit(int stdin_fd, int stdout_fd, char *buf, int buflen, const char *prompt) -{ - struct linenoiseState l; - - /* Populate the linenoise state that we pass to functions implementing - * specific editing functionalities. */ - l.ifd = stdin_fd; - l.ofd = stdout_fd; - l.buf = buf; - l.buflen = buflen; - l.prompt = prompt; - l.oldcolpos = l.pos = 0; - l.len = 0; - l.cols = getColumns(stdin_fd, stdout_fd); - l.maxrows = 0; - l.history_index = 0; - - /* Buffer starts empty. */ - l.buf[0] = '\0'; - l.buflen--; /* Make sure there is always space for the nulterm */ - - /* The latest history entry is always our current buffer, that - * initially is just an empty string. */ - AddHistory(""); - - if (write(l.ofd,prompt, static_cast(l.prompt.length())) == -1) return -1; - while(1) { - int c; - char cbuf[4]; - int nread; - char seq[3]; - -#ifdef _WIN32 - nread = win32read(&c); - if (nread == 1) { - cbuf[0] = c; - } -#else - nread = unicodeReadUTF8Char(l.ifd,cbuf,&c); -#endif - if (nread <= 0) return (int)l.len; - - /* Only autocomplete when the callback is set. It returns < 0 when - * there was an error reading from fd. Otherwise it will return the - * character that should be handled next. */ - if (c == 9 && completionCallback != NULL) { - nread = completeLine(&l,cbuf,&c); - /* Return on errors */ - if (c < 0) return l.len; - /* Read next character when 0 */ - if (c == 0) continue; - } - - switch(c) { - case ENTER: /* enter */ - if (!history.empty()) history.pop_back(); - if (mlmode) linenoiseEditMoveEnd(&l); - return (int)l.len; - case CTRL_C: /* ctrl-c */ - errno = EAGAIN; - return -1; - case BACKSPACE: /* backspace */ - case 8: /* ctrl-h */ - linenoiseEditBackspace(&l); - break; - case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the - line is empty, act as end-of-file. */ - if (l.len > 0) { - linenoiseEditDelete(&l); - } else { - history.pop_back(); - return -1; - } - break; - case CTRL_T: /* ctrl-t, swaps current character with previous. */ - if (l.pos > 0 && l.pos < l.len) { - char aux = buf[l.pos-1]; - buf[l.pos-1] = buf[l.pos]; - buf[l.pos] = aux; - if (l.pos != l.len-1) l.pos++; - refreshLine(&l); - } - break; - case CTRL_B: /* ctrl-b */ - linenoiseEditMoveLeft(&l); - break; - case CTRL_F: /* ctrl-f */ - linenoiseEditMoveRight(&l); - break; - case CTRL_P: /* ctrl-p */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_PREV); - break; - case CTRL_N: /* ctrl-n */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_NEXT); - break; - case ESC: /* escape sequence */ - /* Read the next two bytes representing the escape sequence. - * Use two calls to handle slow terminals returning the two - * chars at different times. */ - if (read(l.ifd,seq,1) == -1) break; - if (read(l.ifd,seq+1,1) == -1) break; - - /* ESC [ sequences. */ - if (seq[0] == '[') { - if (seq[1] >= '0' && seq[1] <= '9') { - /* Extended escape, read additional byte. */ - if (read(l.ifd,seq+2,1) == -1) break; - if (seq[2] == '~') { - switch(seq[1]) { - case '3': /* Delete key. */ - linenoiseEditDelete(&l); - break; - } - } - } else { - switch(seq[1]) { - case 'A': /* Up */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_PREV); - break; - case 'B': /* Down */ - linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_NEXT); - break; - case 'C': /* Right */ - linenoiseEditMoveRight(&l); - break; - case 'D': /* Left */ - linenoiseEditMoveLeft(&l); - break; - case 'H': /* Home */ - linenoiseEditMoveHome(&l); - break; - case 'F': /* End*/ - linenoiseEditMoveEnd(&l); - break; - } - } - } - - /* ESC O sequences. */ - else if (seq[0] == 'O') { - switch(seq[1]) { - case 'H': /* Home */ - linenoiseEditMoveHome(&l); - break; - case 'F': /* End*/ - linenoiseEditMoveEnd(&l); - break; - } - } - break; - default: - if (linenoiseEditInsert(&l,cbuf,nread)) return -1; - break; - case CTRL_U: /* Ctrl+u, delete the whole line. */ - buf[0] = '\0'; - l.pos = l.len = 0; - refreshLine(&l); - break; - case CTRL_K: /* Ctrl+k, delete from current to end of line. */ - buf[l.pos] = '\0'; - l.len = l.pos; - refreshLine(&l); - break; - case CTRL_A: /* Ctrl+a, go to the start of the line */ - linenoiseEditMoveHome(&l); - break; - case CTRL_E: /* ctrl+e, go to the end of the line */ - linenoiseEditMoveEnd(&l); - break; - case CTRL_L: /* ctrl+l, clear screen */ - linenoiseClearScreen(); - refreshLine(&l); - break; - case CTRL_W: /* ctrl+w, delete previous word */ - linenoiseEditDeletePrevWord(&l); - break; - } - } - return l.len; -} - -/* This function calls the line editing function linenoiseEdit() using - * the STDIN file descriptor set in raw mode. */ -inline bool linenoiseRaw(const char *prompt, std::string& line) { - bool quit = false; - - if (!isatty(STDIN_FILENO)) { - /* Not a tty: read from file / pipe. */ - std::getline(std::cin, line); - } else { - /* Interactive editing. */ - if (enableRawMode(STDIN_FILENO) == false) { - return quit; - } - - char buf[LINENOISE_MAX_LINE]; - auto count = linenoiseEdit(STDIN_FILENO, STDOUT_FILENO, buf, LINENOISE_MAX_LINE, prompt); - if (count == -1) { - quit = true; - } else { - line.assign(buf, count); - } - - disableRawMode(STDIN_FILENO); - printf("\n"); - } - return quit; -} - -/* The high level function that is the main API of the linenoise library. - * This function checks if the terminal has basic capabilities, just checking - * for a blacklist of stupid terminals, and later either calls the line - * editing function or uses dummy fgets() so that you will be able to type - * something even in the most desperate of the conditions. */ -inline bool Readline(const char *prompt, std::string& line) { - if (isUnsupportedTerm()) { - printf("%s",prompt); - fflush(stdout); - std::getline(std::cin, line); - return false; - } else { - return linenoiseRaw(prompt, line); - } -} - -inline std::string Readline(const char *prompt, bool& quit) { - std::string line; - quit = Readline(prompt, line); - return line; -} - -inline std::string Readline(const char *prompt) { - bool quit; // dummy - return Readline(prompt, quit); -} - -/* ================================ History ================================= */ - -/* At exit we'll try to fix the terminal to the initial conditions. */ -inline void linenoiseAtExit(void) { - disableRawMode(STDIN_FILENO); -} - -/* This is the API call to add a new entry in the linenoise history. - * It uses a fixed array of char pointers that are shifted (memmoved) - * when the history max length is reached in order to remove the older - * entry and make room for the new one, so it is not exactly suitable for huge - * histories, but will work well for a few hundred of entries. - * - * Using a circular buffer is smarter, but a bit more complex to handle. */ -inline bool AddHistory(const char* line) { - if (history_max_len == 0) return false; - - /* Don't add duplicated lines. */ - if (!history.empty() && history.back() == line) return false; - - /* If we reached the max length, remove the older line. */ - if (history.size() == history_max_len) { - history.erase(history.begin()); - } - history.push_back(line); - - return true; -} - -/* Set the maximum length for the history. This function can be called even - * if there is already some history, the function will make sure to retain - * just the latest 'len' elements if the new history length value is smaller - * than the amount of items already inside the history. */ -inline bool SetHistoryMaxLen(size_t len) { - if (len < 1) return false; - history_max_len = len; - if (len < history.size()) { - history.resize(len); - } - return true; -} - -/* Save the history in the specified file. On success *true* is returned - * otherwise *false* is returned. */ -inline bool SaveHistory(const char* path) { - std::ofstream f(path); // TODO: need 'std::ios::binary'? - if (!f) return false; - for (const auto& h: history) { - f << h << std::endl; - } - return true; -} - -/* Load the history from the specified file. If the file does not exist - * zero is returned and no operation is performed. - * - * If the file exists and the operation succeeded *true* is returned, otherwise - * on error *false* is returned. */ -inline bool LoadHistory(const char* path) { - std::ifstream f(path); - if (!f) return false; - std::string line; - while (std::getline(f, line)) { - AddHistory(line.c_str()); - } - return true; -} - -inline const std::vector& GetHistory() { - return history; -} - -} // namespace linenoise - -#ifdef _WIN32 -#undef isatty -#undef write -#undef read -#pragma warning(pop) -#endif - -#endif /* __LINENOISE_HPP */ diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index e8e3b3156..59e125740 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2536,7 +2536,7 @@ TEST_CASE("autocomplete_documentation_symbols") TEST_CASE_FIXTURE(ACFixture, "autocomplete_ifelse_expressions") { - check(R"( + check(R"( local temp = false local even = true; local a = true @@ -2551,63 +2551,63 @@ a = if temp then even elseif true then temp e@8 a = if temp then even elseif true then temp else e@9 )"); - auto ac = autocomplete('1'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('2'); - CHECK(ac.entryMap.count("temp") == 0); - CHECK(ac.entryMap.count("true") == 0); - CHECK(ac.entryMap.count("then")); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('3'); - CHECK(ac.entryMap.count("even")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('4'); - CHECK(ac.entryMap.count("even") == 0); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else")); - CHECK(ac.entryMap.count("elseif")); - - ac = autocomplete('5'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('6'); - CHECK(ac.entryMap.count("temp") == 0); - CHECK(ac.entryMap.count("true") == 0); - CHECK(ac.entryMap.count("then")); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('7'); - CHECK(ac.entryMap.count("temp")); - CHECK(ac.entryMap.count("true")); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); - - ac = autocomplete('8'); - CHECK(ac.entryMap.count("even") == 0); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else")); - CHECK(ac.entryMap.count("elseif")); - - ac = autocomplete('9'); - CHECK(ac.entryMap.count("then") == 0); - CHECK(ac.entryMap.count("else") == 0); - CHECK(ac.entryMap.count("elseif") == 0); + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('2'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('3'); + CHECK(ac.entryMap.count("even")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('4'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + + ac = autocomplete('5'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('6'); + CHECK(ac.entryMap.count("temp") == 0); + CHECK(ac.entryMap.count("true") == 0); + CHECK(ac.entryMap.count("then")); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('7'); + CHECK(ac.entryMap.count("temp")); + CHECK(ac.entryMap.count("true")); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); + + ac = autocomplete('8'); + CHECK(ac.entryMap.count("even") == 0); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else")); + CHECK(ac.entryMap.count("elseif")); + + ac = autocomplete('9'); + CHECK(ac.entryMap.count("then") == 0); + CHECK(ac.entryMap.count("else") == 0); + CHECK(ac.entryMap.count("elseif") == 0); } TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index 4a28bdde9..d8af94dba 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -611,7 +611,8 @@ TEST_CASE("TableLiteralsIndexConstant") CHECK_EQ("\n" + compileFunction0(R"( local a, b = "key", "value" return {[a] = 42, [b] = 0} -)"), R"( +)"), + R"( NEWTABLE R0 2 0 LOADN R1 42 SETTABLEKS R1 R0 K0 @@ -624,7 +625,8 @@ RETURN R0 1 CHECK_EQ("\n" + compileFunction0(R"( local a, b = 1, 2 return {[a] = 42, [b] = 0} -)"), R"( +)"), + R"( NEWTABLE R0 0 2 LOADN R1 42 SETTABLEN R1 R0 1 @@ -789,8 +791,6 @@ RETURN R0 1 TEST_CASE("TableSizePredictionLoop") { - ScopedFastFlag sff("LuauPredictTableSizeLoop", true); - CHECK_EQ("\n" + compileFunction0(R"( local t = {} for i=1,4 do @@ -2827,7 +2827,7 @@ RETURN R1 -1 TEST_CASE("FastcallSelect") { - ScopedFastFlag sff("LuauCompileSelectBuiltin", true); + ScopedFastFlag sff("LuauCompileSelectBuiltin2", true); // select(_, ...) compiles to a builtin call CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( @@ -2846,7 +2846,8 @@ for i=1, select('#', ...) do sum += select(i, ...) end return sum -)"), R"( +)"), + R"( LOADN R0 0 LOADN R3 1 LOADK R5 K0 @@ -2856,13 +2857,14 @@ GETVARARGS R6 -1 CALL R4 -1 1 MOVE R1 R4 LOADN R2 1 -FORNPREP R1 +7 -FASTCALL1 57 R3 +3 +FORNPREP R1 +8 +FASTCALL1 57 R3 +4 GETIMPORT R4 2 +MOVE R5 R3 GETVARARGS R6 -1 CALL R4 -1 1 ADD R0 R0 R4 -FORNLOOP R1 -7 +FORNLOOP R1 -8 RETURN R0 1 )"); diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 914b881f7..e580949f5 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -492,7 +492,6 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag sffr("LuauBytecodeV2Read", true); ScopedFastFlag sffw("LuauBytecodeV2Write", true); runConformance("debug.lua"); diff --git a/tests/LValue.test.cpp b/tests/LValue.test.cpp index 8a092779c..606f6de39 100644 --- a/tests/LValue.test.cpp +++ b/tests/LValue.test.cpp @@ -38,8 +38,6 @@ TEST_SUITE_BEGIN("LValue"); TEST_CASE("Luau_merge_hashmap_order") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - std::string a = "a"; std::string b = "b"; std::string c = "c"; @@ -58,20 +56,18 @@ TEST_CASE("Luau_merge_hashmap_order") TypeArena arena; merge(arena, m, other); - REQUIRE_EQ(3, m.NEW_refinements.size()); - REQUIRE(m.NEW_refinements.count(mkSymbol(a))); - REQUIRE(m.NEW_refinements.count(mkSymbol(b))); - REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); - CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("string", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); } TEST_CASE("Luau_merge_hashmap_order2") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - std::string a = "a"; std::string b = "b"; std::string c = "c"; @@ -90,20 +86,18 @@ TEST_CASE("Luau_merge_hashmap_order2") TypeArena arena; merge(arena, m, other); - REQUIRE_EQ(3, m.NEW_refinements.size()); - REQUIRE(m.NEW_refinements.count(mkSymbol(a))); - REQUIRE(m.NEW_refinements.count(mkSymbol(b))); - REQUIRE(m.NEW_refinements.count(mkSymbol(c))); + REQUIRE_EQ(3, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(b)])); - CHECK_EQ("boolean | number", toString(m.NEW_refinements[mkSymbol(c)])); + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("string", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | number", toString(m[mkSymbol(c)])); } TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - std::string a = "a"; std::string b = "b"; std::string c = "c"; @@ -125,18 +119,18 @@ TEST_CASE("one_map_has_overlap_at_end_whereas_other_has_it_in_start") TypeArena arena; merge(arena, m, other); - REQUIRE_EQ(5, m.NEW_refinements.size()); - REQUIRE(m.NEW_refinements.count(mkSymbol(a))); - REQUIRE(m.NEW_refinements.count(mkSymbol(b))); - REQUIRE(m.NEW_refinements.count(mkSymbol(c))); - REQUIRE(m.NEW_refinements.count(mkSymbol(d))); - REQUIRE(m.NEW_refinements.count(mkSymbol(e))); - - CHECK_EQ("string", toString(m.NEW_refinements[mkSymbol(a)])); - CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(b)])); - CHECK_EQ("boolean | string", toString(m.NEW_refinements[mkSymbol(c)])); - CHECK_EQ("number", toString(m.NEW_refinements[mkSymbol(d)])); - CHECK_EQ("boolean", toString(m.NEW_refinements[mkSymbol(e)])); + REQUIRE_EQ(5, m.size()); + REQUIRE(m.count(mkSymbol(a))); + REQUIRE(m.count(mkSymbol(b))); + REQUIRE(m.count(mkSymbol(c))); + REQUIRE(m.count(mkSymbol(d))); + REQUIRE(m.count(mkSymbol(e))); + + CHECK_EQ("string", toString(m[mkSymbol(a)])); + CHECK_EQ("number", toString(m[mkSymbol(b)])); + CHECK_EQ("boolean | string", toString(m[mkSymbol(c)])); + CHECK_EQ("number", toString(m[mkSymbol(d)])); + CHECK_EQ("boolean", toString(m[mkSymbol(e)])); } TEST_CASE("hashing_lvalue_global_prop_access") @@ -159,7 +153,7 @@ TEST_CASE("hashing_lvalue_global_prop_access") CHECK_EQ(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); - NEW_RefinementMap m; + RefinementMap m; m[t_x1] = getSingletonTypes().stringType; m[t_x2] = getSingletonTypes().numberType; @@ -188,7 +182,7 @@ TEST_CASE("hashing_lvalue_local_prop_access") CHECK_NE(LValueHasher{}(t_x1), LValueHasher{}(t_x2)); CHECK_EQ(LValueHasher{}(t_x2), LValueHasher{}(t_x2)); - NEW_RefinementMap m; + RefinementMap m; m[t_x1] = getSingletonTypes().stringType; m[t_x2] = getSingletonTypes().numberType; diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 5ad06f0d7..d1cc49b2d 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -54,6 +54,17 @@ return _ CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); } +TEST_CASE_FIXTURE(Fixture, "PlaceholderReadGlobal") +{ + LintResult result = lint(R"( +_ = 5 +print(_) +)"); + + CHECK_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Placeholder value '_' is read here; consider using a named variable"); +} + TEST_CASE_FIXTURE(Fixture, "PlaceholderWrite") { LintResult result = lint(R"( @@ -853,7 +864,7 @@ string.format("%Y") local _ = ("%"):format() -- correct format strings, just to uh make sure -string.format("hello %d %f", 4, 5) +string.format("hello %+10d %.02f %%", 4, 5) )"); CHECK_EQ(result.warnings.size(), 4); @@ -1078,16 +1089,18 @@ TEST_CASE_FIXTURE(Fixture, "FormatStringDate") os.date("%") os.date("%L") os.date("%?") +os.date("\0") -- correct formats os.date("it's %c now") os.date("!*t") )"); - CHECK_EQ(result.warnings.size(), 3); + CHECK_EQ(result.warnings.size(), 4); CHECK_EQ(result.warnings[0].text, "Invalid date format: unfinished replacement"); CHECK_EQ(result.warnings[1].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); CHECK_EQ(result.warnings[2].text, "Invalid date format: unexpected replacement character; must be a date format specifier or %"); + CHECK_EQ(result.warnings[3].text, "Invalid date format: date format can not contain null characters"); } TEST_CASE_FIXTURE(Fixture, "FormatStringTyped") @@ -1396,8 +1409,6 @@ end TEST_CASE_FIXTURE(Fixture, "TableOperations") { - ScopedFastFlag sff("LuauLintTableCreateTable", true); - LintResult result = lintTyped(R"( local t = {} local tt = {} @@ -1435,8 +1446,10 @@ table.create(42, {} :: {}) "table.insert may change behavior if the call returns more than one result; consider adding parentheses around second argument"); CHECK_EQ(result.warnings[6].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); CHECK_EQ(result.warnings[7].text, "table.move uses index 0 but arrays are 1-based; did you mean 1 instead?"); - CHECK_EQ(result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); - CHECK_EQ(result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + CHECK_EQ( + result.warnings[8].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); + CHECK_EQ( + result.warnings[9].text, "table.create with a table literal will reuse the same object for all elements; consider using a for loop instead"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditions") diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 90831ee9d..c1a8887b6 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauFixAmbiguousErrorRecoveryInAssign) - using namespace Luau; namespace @@ -1639,10 +1637,7 @@ TEST_CASE_FIXTURE(Fixture, "parse_error_confusing_function_call") "Ambiguous syntax: this looks like an argument list for a function call, but could also be a start of new statement; use ';' to separate " "statements"); - if (FFlag::LuauFixAmbiguousErrorRecoveryInAssign) - CHECK(result4.errors.size() == 1); - else - CHECK(result4.errors.size() == 5); + CHECK(result4.errors.size() == 1); } TEST_CASE_FIXTURE(Fixture, "parse_error_varargs") diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 86165814c..572b882d8 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -209,8 +209,6 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_does_not_propagate_type_info") TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") { - ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; - CheckResult result = check(R"( local a = 55 :: number? local b = a :: number @@ -224,7 +222,6 @@ TEST_CASE_FIXTURE(Fixture, "as_expr_is_bidirectional") TEST_CASE_FIXTURE(Fixture, "as_expr_warns_on_unrelated_cast") { - ScopedFastFlag sff{"LuauBidirectionalAsExpr", true}; ScopedFastFlag sff2{"LuauErrorRecoveryType", true}; CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 5e08654a7..6730bedb0 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -889,4 +889,55 @@ TEST_CASE_FIXTURE(Fixture, "dont_add_definitions_to_persistent_types") REQUIRE(gtv->definition); } +TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") +{ + ScopedFastFlag sff[]{ + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions", true}, + }; + + CheckResult result = check(R"( + local function f(x: (number | boolean)?) + return assert(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_but_only_for_the_first_type") +{ + ScopedFastFlag sff[]{ + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions", true}, + }; + + CheckResult result = check(R"( + local function f(...: number?) + return assert(...) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(...number?) -> (number, ...number?)", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_first_argument_cannot_be_truthy") +{ + ScopedFastFlag sff[]{ + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions", true}, + }; + + CheckResult result = check(R"( + local function f(x: nil) + return assert(x, "hmm") + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("(nil) -> nil", toString(requireType("f"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 47c13be9a..e5eb0dca0 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -176,19 +176,6 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a REQUIRE_EQ("{| [any]: any, x: number, y: number |}", toString(requireType("b"))); } -TEST_CASE_FIXTURE(Fixture, "normal_conditional_expression_has_refinements") -{ - CheckResult result = check(R"( - local foo: {x: number}? = nil - local bar = foo and foo.x -- TODO: Geez. We are inferring the wrong types here. Should be 'number?'. - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Binary and/or return types are straight up wrong. JIRA: CLI-40300 - CHECK_EQ("boolean | number", toString(requireType("bar"))); -} - // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index f346ddfdf..3a610c3a1 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -939,8 +939,6 @@ TEST_CASE_FIXTURE(Fixture, "type_comparison_ifelse_expression") TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_previously_refined") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( local foo: string? = "hi" assert(foo) @@ -955,8 +953,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_a_shadowed_local_that_which_was_pre TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( type T = {x: string | number} local t: T? = {x = "hi"} @@ -974,8 +970,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_refined2") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( type T = { x: { y: number }? } @@ -993,8 +987,6 @@ TEST_CASE_FIXTURE(Fixture, "correctly_lookup_property_whose_base_was_previously_ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") { - ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; - CheckResult result = check(R"( type T = { [string]: { prop: number }? } local t: T = {} @@ -1061,27 +1053,62 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") CHECK_EQ("Dog", toString(requireTypeAtPosition({9, 33}))); } -TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscript_expr_is_constant_string") +TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") { - ScopedFastFlag sff{"LuauRefiLookupFromIndexExpr", true}; + CheckResult result = check(R"( + local function len(a: {any}) + return a and #a or nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions", true}, + {"LuauAssertStripsFalsyTypes", true}, + }; CheckResult result = check(R"( - type T = { [string]: { prop: number }? } - local t: T = {} + local function is_true(b: true) end + local function is_false(b: false) end - if t["hello"] then - local foo = t["hello"].prop + local function f(x: boolean) + if x then + is_true(x) + else + is_false(x) + end end )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") +TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") { + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions", true}, + {"LuauAssertStripsFalsyTypes", true}, + }; + CheckResult result = check(R"( - local function len(a: {any}) - return a and #a or nil + type Ok = { ok: true, value: T } + type Err = { ok: false, error: E } + type Result = Ok | Err + + local function apply(t: Result, f: (T) -> (), g: (E) -> ()) + if t.ok then + f(t.value) + else + g(t.error) + end end )"); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 483109213..f19cb618b 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1482,7 +1482,7 @@ TEST_CASE_FIXTURE(Fixture, "casting_unsealed_tables_with_props_into_table_with_i REQUIRE(tm); CHECK_EQ("{| [string]: string |}", toString(tm->wantedType, o)); // Should t now have an indexer? - // It would if the assignment to rt was correctly typed. + // It would if the assignment to rt was correctly typed. CHECK_EQ("{ [string]: string, foo: number }", toString(tm->givenType, o)); } @@ -2082,7 +2082,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, @@ -2103,7 +2103,7 @@ a.p = { x = 9 } TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_error") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, @@ -2131,7 +2131,7 @@ caused by: TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table_with_indexer") { - ScopedFastFlag sffs[] { + ScopedFastFlag sffs[]{ {"LuauPropertiesGetExpectedType", true}, {"LuauExpectedTypesOfProperties", true}, {"LuauTableSubtypingVariance2", true}, diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index c9b30e1a2..ead3d762c 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -2429,21 +2429,6 @@ TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowi LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "x_or_y_forces_both_x_and_y_to_be_of_same_type_if_either_is_free") -{ - CheckResult result = check(R"( - local function f(x, y) return x or y end - - local x = f(1, 2) - local y = f(3, "foo") - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(*requireType("x"), *typeChecker.numberType); - - CHECK_EQ(result.errors[0], (TypeError{Location{{4, 23}, {4, 28}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); -} - TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") { CheckResult result = check(R"( @@ -4509,7 +4494,7 @@ f(function(x) print(x) end) } TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ +{ ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; CheckResult result = check(R"( @@ -4777,7 +4762,7 @@ local a: X = if true then {"1", 2, 3} else {4, 5, 6} TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") { ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - ScopedFastFlag luauIfElseBranchTypeUnion{ "LuauIfElseBranchTypeUnion", true }; + ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; CheckResult result = check(R"( local a: number? = if true then 1 else nil @@ -5012,16 +4997,14 @@ local b: B = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ( - toString(result.errors[0]), R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' + CHECK_EQ(toString(result.errors[0]), + R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' caused by: Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); } TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") { - ScopedFastFlag sff{"LuauLValueAsKey", true}; - CheckResult result = check(R"( local function f(thing: any | string) local foo = thing.SomeRandomKey @@ -5120,4 +5103,65 @@ end )"); } +TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") +{ + ScopedFastFlag committingTxnLog{"LuauUseCommittingTxnLog", true}; + ScopedFastFlag subtypingVariance{"LuauTableSubtypingVariance2", true}; + + CheckResult result = check(R"( + --!strict + --!nolint + + type FieldSpecifier = { + fieldName: string, + } + + type ReadFieldOptions = FieldSpecifier & { from: number? } + + type Policies = { + getStoreFieldName: (self: Policies, fieldSpec: FieldSpecifier) -> string, + } + + local Policies = {} + + local function foo(p: Policies) + end + + function Policies:getStoreFieldName(specifier: FieldSpecifier): string + return "" + end + + function Policies:readField(options: ReadFieldOptions) + local _ = self:getStoreFieldName(options) + --[[ + Type error: + TypeError { "MainModule", Location { { line = 25, col = 16 }, { line = 25, col = 20 } }, TypeMismatch { Policies, {- getStoreFieldName: (tp1) -> (a, b...) -} } } + ]] + foo(self) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") +{ + ScopedFastFlag noSealedTypeMod{"LuauNoSealedTypeMod", true}; + + fileResolver.source["game/A"] = R"( +export type Type = { unrelated: boolean } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = {} +function x:Destroy(): () end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 1e790eba9..0aeca0965 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -260,4 +260,17 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") CHECK(unifyErrors.size() == 0); } +TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") +{ + ScopedFastFlag luauUnionTagMatchFix{"LuauUnionTagMatchFix", true}; + + TypeVar redirect{FreeTypeVar{TypeLevel{}}}; + TypeVar table{TableTypeVar{}}; + TypeVar metatable{MetatableTypeVar{&redirect, &table}}; + redirect = BoundTypeVar{&metatable}; // Now we have a metatable that is recursive on the table type + TypeVar variant{UnionTypeVar{{&metatable, typeChecker.numberType}}}; + + state.tryUnify(&metatable, &variant); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 188b8ebc4..de091632d 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -118,6 +118,10 @@ assert((function() return #_G end)() == 0) assert((function() return #{1,2} end)() == 2) assert((function() return #'g' end)() == 1) +local ud = newproxy(true) +getmetatable(ud).__len = function() return 42 end +assert((function() return #ud end)() == 42) + assert((function() local a = 1 a = -a return a end)() == -1) -- while/repeat diff --git a/tests/conformance/vararg.lua b/tests/conformance/vararg.lua index 5aaa422be..d05f95776 100644 --- a/tests/conformance/vararg.lua +++ b/tests/conformance/vararg.lua @@ -105,6 +105,45 @@ assert(a==5 and b==4 and c==3 and d==2 and e==1) a,b,c,d,e = f(4) assert(a==nil and b==nil and c==nil and d==nil and e==nil) +-- select tests +a = {select(3, unpack{10,20,30,40})} +assert(table.getn(a) == 2 and a[1] == 30 and a[2] == 40) +a = {select(1)} +assert(next(a) == nil) +a = {select(-1, 3, 5, 7)} +assert(a[1] == 7 and a[2] == nil) +a = {select(-2, 3, 5, 7)} +assert(a[1] == 5 and a[2] == 7 and a[3] == nil) +pcall(select, 10000) +pcall(select, -10000) + +-- select(_, ...) has special optimizations so it needs extra testing +function selectone(n, ...) + local e = select(n, ...) + return e +end + +function selectmany(n, ...) + return table.concat({select(n, ...)}, ',') +end + +assert(selectone('#') == 0) +assert(selectmany('#') == "0") + +assert(selectone('#', 10, 20, 30) == 3) +assert(selectmany('#', 10, 20, 30) == "3") + +assert(selectone(1, 10, 20, 30) == 10) +assert(selectmany(1, 10, 20, 30) == "10,20,30") + +assert(selectone(2, 10, 20, 30) == 20) +assert(selectmany(2, 10, 20, 30) == "20,30") + +assert(selectone(-2, 10, 20, 30) == 20) +assert(selectmany(-2, 10, 20, 30) == "20,30") + +assert(selectone('3', 10, 20, 30) == 30) +assert(selectmany('3', 10, 20, 30) == "30") -- varargs for main chunks f = loadstring[[ return {...} ]] @@ -122,16 +161,5 @@ f = loadstring[[ assert(f("a", "b", nil, {}, assert)) assert(f()) -a = {select(3, unpack{10,20,30,40})} -assert(table.getn(a) == 2 and a[1] == 30 and a[2] == 40) -a = {select(1)} -assert(next(a) == nil) -a = {select(-1, 3, 5, 7)} -assert(a[1] == 7 and a[2] == nil) -a = {select(-2, 3, 5, 7)} -assert(a[1] == 5 and a[2] == 7 and a[3] == nil) -pcall(select, 10000) -pcall(select, -10000) - return('OK') From 4e60eec1fc97132bee2f5bb09664b08f235f2c55 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 3 Feb 2022 16:31:50 -0800 Subject: [PATCH 22/32] Apply fix to the crash --- Analysis/src/TypeInfer.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4d25fe2e3..b9096d2eb 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -4977,6 +4977,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (notEnoughParameters && hasDefaultParameters) { // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.log = TxnLog::empty(); applyTypeFunction.typeArguments.clear(); applyTypeFunction.typePackArguments.clear(); applyTypeFunction.currentModule = currentModule; From 4748777ce850813d639e97338679357171d98289 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 3 Feb 2022 16:43:44 -0800 Subject: [PATCH 23/32] Fix isocline warnings --- extern/isocline/src/isocline.c | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/extern/isocline/src/isocline.c b/extern/isocline/src/isocline.c index 132780628..8b6055cf9 100644 --- a/extern/isocline/src/isocline.c +++ b/extern/isocline/src/isocline.c @@ -13,7 +13,12 @@ // $ gcc -c src/isocline.c //------------------------------------------------------------- #if !defined(IC_SEPARATE_OBJS) -# define _CRT_SECURE_NO_WARNINGS // for msvc +# ifndef _CRT_NONSTDC_NO_WARNINGS +# define _CRT_NONSTDC_NO_WARNINGS // for msvc +# endif +# ifndef _CRT_SECURE_NO_WARNINGS +# define _CRT_SECURE_NO_WARNINGS // for msvc +# endif # define _XOPEN_SOURCE 700 // for wcwidth # define _DEFAULT_SOURCE // ensure usleep stays visible with _XOPEN_SOURCE >= 700 # include "attr.c" From bbae46600635e43d875cc95392a1e8481f445524 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 4 Feb 2022 12:31:19 -0800 Subject: [PATCH 24/32] Sync to upstream/release/513 This takes the extra bug fix for generic name confusion --- Analysis/include/Luau/TypeInfer.h | 3 ++- Analysis/src/TypeInfer.cpp | 7 ++++--- tests/TypeInfer.generics.test.cpp | 25 +++++++++++++++++++++++++ 3 files changed, 31 insertions(+), 4 deletions(-) diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 90dc9f426..f61ecbf52 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -346,7 +346,8 @@ struct TypeChecker // Note: `scope` must be a fresh scope. GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames); + const AstArray& genericNames, const AstArray& genericPackNames, + bool useCache = false); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 4d25fe2e3..e1987937b 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,6 +29,7 @@ LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as fals LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) +LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) @@ -1199,7 +1200,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauProperTypeLevels) aliasScope->level.subLevel = subLevel; - auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks); + auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); @@ -5361,7 +5362,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames) + const AstArray& genericNames, const AstArray& genericPackNames, bool useCache) { LUAU_ASSERT(scope->parent); @@ -5387,7 +5388,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st } TypeId g; - if (FFlag::LuauRecursiveTypeParameterRestriction) + if (FFlag::LuauRecursiveTypeParameterRestriction && (!FFlag::LuauGenericFunctionsDontCacheTypeParams || useCache)) { TypeId& cached = scope->parent->typeAliasTypeParameters[n]; if (!cached) diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index a7f275515..8a2c6f27e 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -673,4 +673,29 @@ local d: D = c R"(Type '() -> ()' could not be converted into '() -> ()'; different number of generic type pack parameters)"); } +TEST_CASE_FIXTURE(Fixture, "generic_functions_dont_cache_type_parameters") +{ + ScopedFastFlag sff{"LuauGenericFunctionsDontCacheTypeParams", true}; + + CheckResult result = check(R"( +-- See https://github.com/Roblox/luau/issues/332 +-- This function has a type parameter with the same name as clones, +-- so if we cache type parameter names for functions these get confused. +-- function id(x : Z) : Z +function id(x : X) : X + return x +end + +function clone(dict: {[X]:Y}): {[X]:Y} + local copy = {} + for k, v in pairs(dict) do + copy[k] = v + end + return copy +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); From e9bf182585e4cfc3bdf9bf71fc74ca947027b8dd Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 11 Feb 2022 10:43:14 -0800 Subject: [PATCH 25/32] Sync to upstream/release/514 --- Analysis/include/Luau/Autocomplete.h | 2 + Analysis/include/Luau/Linter.h | 1 + Analysis/include/Luau/Substitution.h | 2 +- Analysis/include/Luau/TxnLog.h | 4 +- Analysis/include/Luau/TypeInfer.h | 11 ++ Analysis/include/Luau/TypePack.h | 3 - Analysis/include/Luau/TypeVar.h | 3 - Analysis/src/EmbeddedBuiltinDefinitions.cpp | 12 +- Analysis/src/Linter.cpp | 103 ++++++++++++-- Analysis/src/Module.cpp | 28 ++-- Analysis/src/Scope.cpp | 4 + Analysis/src/TxnLog.cpp | 8 ++ Analysis/src/TypeInfer.cpp | 101 +++++++++++--- Analysis/src/TypeVar.cpp | 4 +- Analysis/src/Unifier.cpp | 125 +++++++++++------ Ast/src/Parser.cpp | 2 +- CLI/Ast.cpp | 86 ++++++++++++ CLI/Repl.cpp | 68 ++++++--- CLI/Repl.h | 7 +- CMakeLists.txt | 14 +- Compiler/src/BytecodeBuilder.cpp | 8 +- Compiler/src/Compiler.cpp | 62 ++------- Sources.cmake | 8 ++ VM/src/lbuiltins.cpp | 2 +- VM/src/lgcdebug.cpp | 4 + VM/src/lmem.cpp | 127 ++++++++++++----- VM/src/lobject.cpp | 4 +- VM/src/lvmexecute.cpp | 3 +- extern/isocline/src/isocline.c | 7 +- tests/Compiler.test.cpp | 4 - tests/Conformance.test.cpp | 2 - tests/Linter.test.cpp | 36 ++++- tests/Repl.test.cpp | 92 ++++++++++++ tests/TypeInfer.aliases.test.cpp | 61 ++++++++ tests/TypeInfer.builtins.test.cpp | 15 +- tests/TypeInfer.provisional.test.cpp | 74 +++++++++- tests/TypeInfer.refinements.test.cpp | 22 +-- tests/TypeInfer.test.cpp | 147 ++++++++++++++++++++ tests/TypeInfer.tryUnify.test.cpp | 17 +++ tests/conformance/basic.lua | 10 +- tests/conformance/debug.lua | 1 + tests/conformance/errors.lua | 16 ++- tests/conformance/gc.lua | 7 +- tests/conformance/math.lua | 1 + tests/conformance/vararg.lua | 6 + tests/conformance/vector.lua | 9 ++ tools/heapgraph.py | 33 +++-- tools/svg.py | 2 +- 48 files changed, 1100 insertions(+), 268 deletions(-) create mode 100644 CLI/Ast.cpp diff --git a/Analysis/include/Luau/Autocomplete.h b/Analysis/include/Luau/Autocomplete.h index 585342933..65b788d35 100644 --- a/Analysis/include/Luau/Autocomplete.h +++ b/Analysis/include/Luau/Autocomplete.h @@ -86,6 +86,8 @@ struct OwningAutocompleteResult }; AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName, Position position, StringCompletionCallback callback); + +// Deprecated, do not use in new work. OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback); } // namespace Luau diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index 1f7f7f9dc..ec3c124d7 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -49,6 +49,7 @@ struct LintWarning Code_DeprecatedApi = 22, Code_TableOperations = 23, Code_DuplicateCondition = 24, + Code_MisleadingAndOr = 25, Code__Count }; diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index 4f3307cdf..f85b42690 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -93,7 +93,7 @@ struct Tarjan // This should never be null; ensure you initialize it before calling // substitution methods. - const TxnLog* log; + const TxnLog* log = nullptr; std::vector edgesTy; std::vector edgesTp; diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index 02b873748..f238e258a 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -307,8 +307,8 @@ struct TxnLog // // We can't use a DenseHashMap here because we need a non-const iterator // over the map when we concatenate. - std::unordered_map> typeVarChanges; - std::unordered_map> typePackChanges; + std::unordered_map, DenseHashPointer> typeVarChanges; + std::unordered_map, DenseHashPointer> typePackChanges; TxnLog* parent = nullptr; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index f61ecbf52..5592fa1f5 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -103,6 +103,11 @@ struct GenericTypeDefinitions std::vector genericPacks; }; +struct HashBoolNamePair +{ + size_t operator()(const std::pair& pair) const; +}; + // All TypeVars are retained via Environment::typeVars. All TypeIds // within a program are borrowed pointers into this set. struct TypeChecker @@ -411,6 +416,12 @@ struct TypeChecker private: int checkRecursionCount = 0; int recursionCount = 0; + + /** + * We use this to avoid doing second-pass analysis of type aliases that are duplicates. We record a pair + * (exported, name) to properly deal with the case where the two duplicates do not have the same export status. + */ + DenseHashSet, HashBoolNamePair> duplicateTypeAliases; }; // Unit test hook diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index ca588ccb7..c74bad114 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -54,9 +54,6 @@ struct TypePackVar bool persistent = false; // Pointer to the type arena that allocated this type. - // Do not depend on the value of this under any circumstances. This is for - // debugging purposes only. This is only set in debug builds; it is nullptr - // in all other environments. TypeArena* owningArena = nullptr; }; diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 11dc93773..8d1a9fa6c 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -449,9 +449,6 @@ struct TypeVar final std::optional documentationSymbol; // Pointer to the type arena that allocated this type. - // Do not depend on the value of this under any circumstances. This is for - // debugging purposes only. This is only set in debug builds; it is nullptr - // in all other environments. TypeArena* owningArena = nullptr; bool operator==(const TypeVar& rhs) const; diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 249825067..f3ef88fc5 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -1,8 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -LUAU_FASTFLAGVARIABLE(LuauFixTonumberReturnType, false) - namespace Luau { @@ -115,6 +113,7 @@ declare function gcinfo(): number declare function error(message: T, level: number?) declare function tostring(value: T): string + declare function tonumber(value: T, radix: number?): number? declare function rawequal(a: T1, b: T2): boolean declare function rawget(tab: {[K]: V}, k: K): V @@ -200,14 +199,7 @@ declare function gcinfo(): number std::string getBuiltinDefinitionSource() { - std::string result = kBuiltinDefinitionLuaSrc; - - if (FFlag::LuauFixTonumberReturnType) - result += "declare function tonumber(value: T, radix: number?): number?\n"; - else - result += "declare function tonumber(value: T, radix: number?): number\n"; - - return result; + return kBuiltinDefinitionLuaSrc; } } // namespace Luau diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 57a33e931..2ba6a0fce 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -43,6 +43,7 @@ static const char* kWarningNames[] = { "DeprecatedApi", "TableOperations", "DuplicateCondition", + "MisleadingAndOr", }; // clang-format on @@ -2040,18 +2041,28 @@ class LintDeprecatedApi : AstVisitor const Property* prop = lookupClassProp(cty, node->index.value); if (prop && prop->deprecated) - { - if (!prop->deprecatedSuggestion.empty()) - emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated, use '%s' instead", - cty->name.c_str(), node->index.value, prop->deprecatedSuggestion.c_str()); - else - emitWarning(*context, LintWarning::Code_DeprecatedApi, node->location, "Member '%s.%s' is deprecated", cty->name.c_str(), - node->index.value); - } + report(node->location, *prop, cty->name.c_str(), node->index.value); + } + else if (const TableTypeVar* tty = get(follow(*ty))) + { + auto prop = tty->props.find(node->index.value); + + if (prop != tty->props.end() && prop->second.deprecated) + report(node->location, prop->second, tty->name ? tty->name->c_str() : nullptr, node->index.value); } return true; } + + void report(const Location& location, const Property& prop, const char* container, const char* field) + { + std::string suggestion = prop.deprecatedSuggestion.empty() ? "" : format(", use '%s' instead", prop.deprecatedSuggestion.c_str()); + + if (container) + emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s.%s' is deprecated%s", container, field, suggestion.c_str()); + else + emitWarning(*context, LintWarning::Code_DeprecatedApi, location, "Member '%s' is deprecated%s", field, suggestion.c_str()); + } }; class LintTableOperations : AstVisitor @@ -2257,6 +2268,39 @@ class LintDuplicateCondition : AstVisitor return false; } + bool visit(AstExprIfElse* expr) override + { + if (!expr->falseExpr->is()) + return true; + + // if..elseif chain detected, we need to unroll it + std::vector conditions; + conditions.reserve(2); + + AstExprIfElse* head = expr; + while (head) + { + head->condition->visit(this); + head->trueExpr->visit(this); + + conditions.push_back(head->condition); + + if (head->falseExpr->is()) + { + head = head->falseExpr->as(); + continue; + } + + head->falseExpr->visit(this); + break; + } + + detectDuplicates(conditions); + + // block recursive visits so that we only analyze each chain once + return false; + } + bool visit(AstExprBinary* expr) override { if (expr->op != AstExprBinary::And && expr->op != AstExprBinary::Or) @@ -2418,6 +2462,46 @@ class LintDuplicateLocal : AstVisitor } }; +class LintMisleadingAndOr : AstVisitor +{ +public: + LUAU_NOINLINE static void process(LintContext& context) + { + LintMisleadingAndOr pass; + pass.context = &context; + + context.root->visit(&pass); + } + +private: + LintContext* context; + + bool visit(AstExprBinary* node) override + { + if (node->op != AstExprBinary::Or) + return true; + + AstExprBinary* and_ = node->left->as(); + if (!and_ || and_->op != AstExprBinary::And) + return true; + + const char* alt = nullptr; + + if (and_->right->is()) + alt = "nil"; + else if (AstExprConstantBool* c = and_->right->as(); c && c->value == false) + alt = "false"; + + if (alt) + emitWarning(*context, LintWarning::Code_MisleadingAndOr, node->location, + "The and-or expression always evaluates to the second alternative because the first alternative is %s; consider using if-then-else " + "expression instead", + alt); + + return true; + } +}; + static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, const ScopePtr& env) { ScopePtr current = env; @@ -2522,6 +2606,9 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_DuplicateLocal)) LintDuplicateLocal::process(context); + if (context.warningEnabled(LintWarning::Code_MisleadingAndOr)) + LintMisleadingAndOr::process(context); + std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 4fdff8f7a..817a33e9f 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -12,10 +12,10 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) +LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypeAliasDefaults) - +LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false) namespace Luau @@ -66,7 +66,7 @@ TypeId TypeArena::addTV(TypeVar&& tv) { TypeId allocated = typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -76,7 +76,7 @@ TypeId TypeArena::freshType(TypeLevel level) { TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -86,7 +86,7 @@ TypePackId TypeArena::addTypePack(std::initializer_list types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -96,7 +96,7 @@ TypePackId TypeArena::addTypePack(std::vector types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -106,7 +106,7 @@ TypePackId TypeArena::addTypePack(TypePack tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -116,7 +116,7 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena) + if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) asMutable(allocated)->owningArena = this; return allocated; @@ -454,8 +454,16 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - // TODO: Make this work when the arena of 'res' might be frozen - asMutable(res)->documentationSymbol = typeId->documentationSymbol; + if (FFlag::LuauImmutableTypes) + { + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } + else + { + asMutable(res)->documentationSymbol = typeId->documentationSymbol; + } } return res; diff --git a/Analysis/src/Scope.cpp b/Analysis/src/Scope.cpp index c30db9c25..0a362a5eb 100644 --- a/Analysis/src/Scope.cpp +++ b/Analysis/src/Scope.cpp @@ -2,6 +2,8 @@ #include "Luau/Scope.h" +LUAU_FASTFLAG(LuauTwoPassAliasDefinitionFix); + namespace Luau { @@ -17,6 +19,8 @@ Scope::Scope(const ScopePtr& parent, int subLevel) , returnType(parent->returnType) , level(parent->level.incr()) { + if (FFlag::LuauTwoPassAliasDefinitionFix) + level = level.incr(); level.subLevel = subLevel; } diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 0968a4c10..00067bdd1 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -250,6 +250,10 @@ PendingTypePack* TxnLog::queue(TypePackId tp) PendingType* TxnLog::pending(TypeId ty) const { + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); + for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typeVarChanges.find(ty); it != current->typeVarChanges.end()) @@ -261,6 +265,10 @@ PendingType* TxnLog::pending(TypeId ty) const PendingTypePack* TxnLog::pending(TypePackId tp) const { + // This function will technically work if `this` is nullptr, but this + // indicates a bug, so we explicitly assert. + LUAU_ASSERT(static_cast(this) != nullptr); + for (const TxnLog* current = this; current; current = current->parent) { if (auto it = current->typePackChanges.find(tp); it != current->typePackChanges.end()) diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index e1987937b..f1c314cd2 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -32,12 +32,13 @@ LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) +LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions, false) +LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) @@ -47,7 +48,10 @@ LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) +LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) +LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. +LUAU_FASTFLAGVARIABLE(LuauAnotherTypeLevelFix, false) namespace Luau { @@ -213,6 +217,11 @@ static bool isMetamethod(const Name& name) name == "__metatable" || name == "__eq" || name == "__lt" || name == "__le" || name == "__mode"; } +size_t HashBoolNamePair::operator()(const std::pair& pair) const +{ + return std::hash()(pair.first) ^ std::hash()(pair.second); +} + TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHandler) : resolver(resolver) , iceHandler(iceHandler) @@ -225,6 +234,7 @@ TypeChecker::TypeChecker(ModuleResolver* resolver, InternalErrorReporter* iceHan , anyType(getSingletonTypes().anyType) , optionalNumberType(getSingletonTypes().optionalNumberType) , anyTypePack(getSingletonTypes().anyTypePack) + , duplicateTypeAliases{{false, {}}} { globalScope = std::make_shared(globalTypes.addTypePack(TypePackVar{FreeTypePack{TypeLevel{}}})); @@ -291,6 +301,9 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona unifierState.skipCacheForType.clear(); } + if (FFlag::LuauTwoPassAliasDefinitionFix) + duplicateTypeAliases.clear(); + return std::move(currentModule); } @@ -496,6 +509,9 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { + if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == Parser::errorName) + continue; + auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; Name name = typealias->name.value; @@ -1176,6 +1192,10 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias // Once with forwardDeclare, and once without. Name name = typealias.name.value; + // If the alias is missing a name, we can't do anything with it. Ignore it. + if (FFlag::LuauTwoPassAliasDefinitionFix && name == Parser::errorName) + return; + std::optional binding; if (auto it = scope->exportedTypeBindings.find(name); it != scope->exportedTypeBindings.end()) binding = it->second; @@ -1192,6 +1212,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias reportError(TypeError{typealias.location, DuplicateTypeDefinition{name, location}}); bindingsMap[name] = TypeFun{binding->typeParams, binding->typePackParams, errorRecoveryType(anyType)}; + if (FFlag::LuauTwoPassAliasDefinitionFix) + duplicateTypeAliases.insert({typealias.exported, name}); } else { @@ -1211,6 +1233,11 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } else { + // If the first pass failed (this should mean a duplicate definition), the second pass isn't going to be + // interesting. + if (FFlag::LuauTwoPassAliasDefinitionFix && duplicateTypeAliases.find({typealias.exported, name})) + return; + if (!binding) ice("Not predeclared"); @@ -1235,7 +1262,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (auto ttv = getMutable(follow(ty))) { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy - if (ttv->name) + // Additionally, we can't modify types that come from other modules + if (ttv->name || (FFlag::LuauImmutableTypes && follow(ty)->owningArena != ¤tModule->internalTypes)) { bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), binding->typeParams.end(), [](auto&& itp, auto&& tp) { @@ -1247,7 +1275,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias }); // Copy can be skipped if this is an identical alias - if (ttv->name != name || !sameTys || !sameTps) + if ((FFlag::LuauImmutableTypes && !ttv->name) || ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1279,9 +1307,17 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias } } else if (auto mtv = getMutable(follow(ty))) - mtv->syntheticName = name; + { + // We can't modify types that come from other modules + if (!FFlag::LuauImmutableTypes || follow(ty)->owningArena == ¤tModule->internalTypes) + mtv->syntheticName = name; + } + + TypeId& bindingType = bindingsMap[name].type; + bool ok = unify(ty, bindingType, typealias.location); - unify(ty, bindingsMap[name].type, typealias.location); + if (FFlag::LuauTwoPassAliasDefinitionFix && ok) + bindingType = ty; } } @@ -1564,7 +1600,12 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprCa else if (auto vtp = get(retPack)) return {vtp->ty, std::move(result.predicates)}; else if (get(retPack)) - ice("Unexpected abstract type pack!", expr.location); + { + if (FFlag::LuauReturnAnyInsteadOfICE) + return {anyType, std::move(result.predicates)}; + else + ice("Unexpected abstract type pack!", expr.location); + } else ice("Unknown TypePack type!", expr.location); } @@ -1614,11 +1655,23 @@ std::optional TypeChecker::getIndexTypeFromType( tablify(type); - const PrimitiveTypeVar* primitiveType = get(type); - if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) + if (FFlag::LuauDiscriminableUnions2) { - if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) + if (isString(type)) + { + std::optional mtIndex = findMetatableEntry(stringType, "__index", location); + LUAU_ASSERT(mtIndex); type = *mtIndex; + } + } + else + { + const PrimitiveTypeVar* primitiveType = get(type); + if (primitiveType && primitiveType->type == PrimitiveTypeVar::String) + { + if (std::optional mtIndex = findMetatableEntry(type, "__index", location)) + type = *mtIndex; + } } if (TableTypeVar* tableType = getMutableTableType(type)) @@ -2476,7 +2529,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); - return {checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy), + return {checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy), {AndPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::Or) @@ -2489,7 +2542,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi auto [rhsTy, rhsPredicates] = checkExpr(innerScope, *expr.right); // Because of C++, I'm not sure if lhsPredicates was not moved out by the time we call checkBinaryOperation. - TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); + TypeId result = checkBinaryOperation(FFlag::LuauDiscriminableUnions2 ? scope : innerScope, expr, lhsTy, rhsTy, lhsPredicates); return {result, {OrPredicate{std::move(lhsPredicates), std::move(rhsPredicates)}}}; } else if (expr.op == AstExprBinary::CompareEq || expr.op == AstExprBinary::CompareNe) @@ -2497,8 +2550,8 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprBi if (auto predicate = tryGetTypeGuardPredicate(expr)) return {booleanType, {std::move(*predicate)}}; - ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); - ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions); + ExprResult lhs = checkExpr(scope, *expr.left, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); + ExprResult rhs = checkExpr(scope, *expr.right, std::nullopt, /*forceSingleton=*/FFlag::LuauDiscriminableUnions2); PredicateVec predicates; @@ -2785,12 +2838,16 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = freshType(scope); + TypeId resultType = freshType(FFlag::LuauAnotherTypeLevelFix ? exprTable->level : scope->level); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } else { + /* + * If we use [] indexing to fetch a property from a sealed table that has no indexer, we have no idea if it will + * work, so we just mint a fresh type, return that, and hope for the best. + */ TypeId resultType = freshType(scope); return resultType; } @@ -4195,6 +4252,9 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } + if (FFlag::LuauImmutableTypes) + return *moduleType; + SeenTypes seenTypes; SeenTypePacks seenTypePacks; CloneState cloneState; @@ -4978,6 +5038,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (notEnoughParameters && hasDefaultParameters) { // 'applyTypeFunction' is used to substitute default types that reference previous generic types + applyTypeFunction.log = TxnLog::empty(); applyTypeFunction.typeArguments.clear(); applyTypeFunction.typePackArguments.clear(); applyTypeFunction.currentModule = currentModule; @@ -5445,7 +5506,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) { - LUAU_ASSERT(FFlag::LuauDiscriminableUnions); + LUAU_ASSERT(FFlag::LuauDiscriminableUnions2); const LValue* target = &lvalue; std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. @@ -5658,7 +5719,7 @@ void TypeChecker::resolve(const TruthyPredicate& truthyP, ErrorVec& errVec, Refi return std::nullopt; }; - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { std::optional ty = resolveLValue(refis, scope, truthyP.lvalue); if (ty && fromOr) @@ -5771,7 +5832,7 @@ void TypeChecker::resolve(const IsAPredicate& isaP, ErrorVec& errVec, Refinement return res; }; - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { refineLValue(isaP.lvalue, refis, scope, predicate); } @@ -5846,7 +5907,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec if (auto it = primitives.find(typeguardP.kind); it != primitives.end()) { - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { refineLValue(typeguardP.lvalue, refis, scope, it->second(sense)); return; @@ -5868,7 +5929,7 @@ void TypeChecker::resolve(const TypeGuardPredicate& typeguardP, ErrorVec& errVec } auto fail = [&](const TypeErrorData& err) { - if (!FFlag::LuauDiscriminableUnions) + if (!FFlag::LuauDiscriminableUnions2) errVec.push_back(TypeError{typeguardP.location, err}); addRefinement(refis, typeguardP.lvalue, errorRecoveryType(scope)); }; @@ -5900,7 +5961,7 @@ void TypeChecker::resolve(const EqPredicate& eqP, ErrorVec& errVec, RefinementMa return {ty}; }; - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { std::vector rhs = options(eqP.type); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 2321eafda..7e438e31c 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -28,6 +28,7 @@ LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauUnionTagMatchFix) +LUAU_FASTFLAG(LuauDiscriminableUnions2) namespace Luau { @@ -393,7 +394,8 @@ bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) if (seen.contains(ty)) return true; - if (isPrim(ty, PrimitiveTypeVar::String) || get(ty) || get(ty) || get(ty)) + bool isStr = FFlag::LuauDiscriminableUnions2 ? isString(ty) : isPrim(ty, PrimitiveTypeVar::String); + if (isStr || get(ty) || get(ty) || get(ty)) return true; if (auto uty = get(ty)) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 89e4ae237..a8ad51593 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -15,6 +15,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) +LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); @@ -24,6 +25,7 @@ LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) +LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) namespace Luau { @@ -32,11 +34,13 @@ struct PromoteTypeLevels { DEPRECATED_TxnLog& DEPRECATED_log; TxnLog& log; + const TypeArena* typeArena = nullptr; TypeLevel minLevel; - explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel) + explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) : DEPRECATED_log(DEPRECATED_log) , log(log) + , typeArena(typeArena) , minLevel(minLevel) { } @@ -65,8 +69,12 @@ struct PromoteTypeLevels } template - bool operator()(TID, const T&) + bool operator()(TID ty, const T&) { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + return true; } @@ -83,12 +91,20 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const FunctionTypeVar&) { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); return true; } bool operator()(TypeId ty, const TableTypeVar& ttv) { + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + if (ttv.state != TableState::Free && ttv.state != TableState::Generic) return true; @@ -108,24 +124,33 @@ struct PromoteTypeLevels } }; -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypeId ty) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { - PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return; + + PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, ptl, seen); } -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { - PromoteTypeLevels ptl{DEPRECATED_log, log, minLevel}; + // Type levels of types from other modules are already global, so we don't need to promote anything inside + if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + return; + + PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(tp, ptl, seen); } struct SkipCacheForType { - SkipCacheForType(const DenseHashMap& skipCacheForType) + SkipCacheForType(const DenseHashMap& skipCacheForType, const TypeArena* typeArena) : skipCacheForType(skipCacheForType) + , typeArena(typeArena) { } @@ -152,6 +177,10 @@ struct SkipCacheForType bool operator()(TypeId ty, const TableTypeVar&) { + // Types from other modules don't contain mutable elements and are ok to cache + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + TableTypeVar& ttv = *getMutable(ty); if (ttv.boundTo) @@ -172,6 +201,10 @@ struct SkipCacheForType template bool operator()(TypeId ty, const T& t) { + // Types from other modules don't contain mutable elements and are ok to cache + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + return false; + const bool* prev = skipCacheForType.find(ty); if (prev && *prev) @@ -184,8 +217,12 @@ struct SkipCacheForType } template - bool operator()(TypePackId, const T&) + bool operator()(TypePackId tp, const T&) { + // Types from other modules don't contain mutable elements and are ok to cache + if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + return false; + return true; } @@ -208,6 +245,7 @@ struct SkipCacheForType } const DenseHashMap& skipCacheForType; + const TypeArena* typeArena = nullptr; bool result = false; }; @@ -422,13 +460,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (FFlag::LuauUseCommittingTxnLog) { - promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); log.replace(superTy, BoundTypeVar(subTy)); } else { if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, superLevel, subTy); + promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); else if (auto subLevel = getMutableLevel(subTy)) { if (!subLevel->subsumes(superFree->level)) @@ -466,13 +504,13 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { if (FFlag::LuauUseCommittingTxnLog) { - promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); log.replace(subTy, BoundTypeVar(superTy)); } else { if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, subLevel, superTy); + promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); else if (auto superLevel = getMutableLevel(superTy)) { if (!superLevel->subsumes(subFree->level)) @@ -849,7 +887,7 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy) return; auto skipCacheFor = [this](TypeId ty) { - SkipCacheForType visitor{sharedState.skipCacheForType}; + SkipCacheForType visitor{sharedState.skipCacheForType, types}; visitTypeVarOnce(ty, visitor, sharedState.seenAny); sharedState.skipCacheForType[ty] = visitor.result; @@ -1637,32 +1675,35 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(subFunction->retType, superFunction->retType); } - if (FFlag::LuauUseCommittingTxnLog) - { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - PendingType* newSubTy = log.queue(subTy); - FunctionTypeVar* newSubFtv = getMutable(newSubTy); - LUAU_ASSERT(newSubFtv); - newSubFtv->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - PendingType* newSuperTy = log.queue(superTy); - FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); - LUAU_ASSERT(newSuperFtv); - newSuperFtv->definition = subFunction->definition; - } - } - else + if (!FFlag::LuauImmutableTypes) { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) + if (FFlag::LuauUseCommittingTxnLog) { - subFunction->definition = superFunction->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + PendingType* newSubTy = log.queue(subTy); + FunctionTypeVar* newSubFtv = getMutable(newSubTy); + LUAU_ASSERT(newSubFtv); + newSubFtv->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + PendingType* newSuperTy = log.queue(superTy); + FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); + LUAU_ASSERT(newSuperFtv); + newSuperFtv->definition = subFunction->definition; + } } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + else { - superFunction->definition = subFunction->definition; + if (superFunction->definition && !subFunction->definition && !subTy->persistent) + { + subFunction->definition = superFunction->definition; + } + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) + { + superFunction->definition = subFunction->definition; + } } } @@ -2631,7 +2672,7 @@ static void queueTypePack(std::vector& queue, DenseHashSet& { while (true) { - a = follow(a); + a = FFlag::LuauFollowWithCommittingTxnLogInAnyUnification ? state.log.follow(a) : follow(a); if (seenTypePacks.find(a)) break; @@ -2738,7 +2779,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever } static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHashSet& seen, DenseHashSet& seenTypePacks, - TypeId anyType, TypePackId anyTypePack) + const TypeArena* typeArena, TypeId anyType, TypePackId anyTypePack) { while (!queue.empty()) { @@ -2746,8 +2787,14 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas { TypeId ty = state.log.follow(queue.back()); queue.pop_back(); + + // Types from other modules don't have free types + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + continue; + if (seen.find(ty)) continue; + seen.insert(ty); if (state.log.getMutable(ty)) @@ -2853,7 +2900,7 @@ void Unifier::tryUnifyWithAny(TypeId subTy, TypeId anyTy) sharedState.tempSeenTy.clear(); sharedState.tempSeenTp.clear(); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, getSingletonTypes().anyType, anyTP); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, getSingletonTypes().anyType, anyTP); } void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) @@ -2869,7 +2916,7 @@ void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) queueTypePack(queue, sharedState.tempSeenTp, *this, subTy, anyTp); - Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, anyTy, anyTp); + Luau::tryUnifyWithAny(queue, *this, sharedState.tempSeenTy, sharedState.tempSeenTp, types, anyTy, anyTp); } std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index f559e2e07..30b32f914 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -1133,7 +1133,7 @@ AstTypePack* Parser::parseTypeList(TempVector& result, TempVector + +#include "Luau/Common.h" +#include "Luau/Ast.h" +#include "Luau/JsonEncoder.h" +#include "Luau/Parser.h" +#include "Luau/ParseOptions.h" + +#include "FileUtils.h" + +static void displayHelp(const char* argv0) +{ + printf("Usage: %s [file]\n", argv0); +} + +static int assertionHandler(const char* expr, const char* file, int line, const char* function) +{ + printf("%s(%d): ASSERTION FAILED: %s\n", file, line, expr); + return 1; +} + +int main(int argc, char** argv) +{ + Luau::assertHandler() = assertionHandler; + + for (Luau::FValue* flag = Luau::FValue::list; flag; flag = flag->next) + if (strncmp(flag->name, "Luau", 4) == 0) + flag->value = true; + + if (argc >= 2 && strcmp(argv[1], "--help") == 0) + { + displayHelp(argv[0]); + return 0; + } + else if (argc < 2) + { + displayHelp(argv[0]); + return 1; + } + + const char* name = argv[1]; + std::optional maybeSource = std::nullopt; + if (strcmp(name, "-") == 0) + { + maybeSource = readStdin(); + } + else + { + maybeSource = readFile(name); + } + + if (!maybeSource) + { + fprintf(stderr, "Couldn't read source %s\n", name); + return 1; + } + + std::string source = *maybeSource; + + Luau::Allocator allocator; + Luau::AstNameTable names(allocator); + + Luau::ParseOptions options; + options.supportContinueStatement = true; + options.allowTypeAnnotations = true; + options.allowDeclarationSyntax = true; + + Luau::ParseResult parseResult = Luau::Parser::parse(source.data(), source.size(), names, allocator, options); + + if (parseResult.errors.size() > 0) + { + fprintf(stderr, "Parse errors were encountered:\n"); + for (const Luau::ParseError& error : parseResult.errors) + { + fprintf(stderr, " %s - %s\n", toString(error.getLocation()).c_str(), error.getMessage().c_str()); + } + fprintf(stderr, "\n"); + } + + printf("%s", Luau::toJson(parseResult.root).c_str()); + + return parseResult.errors.size() > 0 ? 1 : 0; +} + + diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 5af6b508f..9a6e25c28 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -1,4 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#include "Repl.h" + #include "lua.h" #include "lualib.h" @@ -38,13 +40,14 @@ enum class CompileFormat struct GlobalOptions { int optimizationLevel = 1; + int debugLevel = 1; } globalOptions; static Luau::CompileOptions copts() { Luau::CompileOptions result = {}; result.optimizationLevel = globalOptions.optimizationLevel; - result.debugLevel = 1; + result.debugLevel = globalOptions.debugLevel; result.coverageLevel = coverageActive() ? 2 : 0; return result; @@ -240,9 +243,8 @@ std::string runCode(lua_State* L, const std::string& source) return std::string(); } -static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) +static void completeIndexer(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) { - auto* L = reinterpret_cast(ic_completion_arg(cenv)); std::string_view lookup = editBuffer; char lastSep = 0; @@ -276,7 +278,7 @@ static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) // Add an opening paren for function calls by default. completion += "("; } - ic_add_completion_ex(cenv, completion.data(), key.data(), nullptr); + addCompletionCallback(completion, std::string(key)); } } lua_pop(L, 1); @@ -295,10 +297,11 @@ static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) { // Replace the string object with the string class to perform further lookups of string functions // Note: We retrieve the string class from _G to prevent issues if the user assigns to `string`. + lua_pop(L, 1); // Pop the string instance lua_getglobal(L, "_G"); lua_pushlstring(L, "string", 6); lua_rawget(L, -2); - lua_remove(L, -2); + lua_remove(L, -2); // Remove the global table LUAU_ASSERT(lua_istable(L, -1)); } else if (!lua_istable(L, -1)) @@ -312,6 +315,26 @@ static void completeIndexer(ic_completion_env_t* cenv, const char* editBuffer) lua_pop(L, 1); } +void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) +{ + // look the value up in current global table first + lua_pushvalue(L, LUA_GLOBALSINDEX); + completeIndexer(L, editBuffer, addCompletionCallback); + + // and in actual global table after that + lua_getglobal(L, "_G"); + completeIndexer(L, editBuffer, addCompletionCallback); +} + +static void icGetCompletions(ic_completion_env_t* cenv, const char* editBuffer) +{ + auto* L = reinterpret_cast(ic_completion_arg(cenv)); + + getCompletions(L, std::string(editBuffer), [cenv](const std::string& completion, const std::string& display) { + ic_add_completion_ex(cenv, completion.data(), display.data(), nullptr); + }); +} + static bool isMethodOrFunctionChar(const char* s, long len) { char c = *s; @@ -320,15 +343,7 @@ static bool isMethodOrFunctionChar(const char* s, long len) static void completeRepl(ic_completion_env_t* cenv, const char* editBuffer) { - auto* L = reinterpret_cast(ic_completion_arg(cenv)); - - // look the value up in current global table first - lua_pushvalue(L, LUA_GLOBALSINDEX); - ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); - - // and in actual global table after that - lua_getglobal(L, "_G"); - ic_complete_word(cenv, editBuffer, completeIndexer, isMethodOrFunctionChar); + ic_complete_word(cenv, editBuffer, icGetCompletions, isMethodOrFunctionChar); } struct LinenoiseScopedHistory @@ -372,19 +387,20 @@ static void runReplImpl(lua_State* L) for (;;) { - const char* line = ic_readline(buffer.empty() ? "" : ">"); + const char* prompt = buffer.empty() ? "" : ">"; + std::unique_ptr line(ic_readline(prompt), free); if (!line) break; - if (buffer.empty() && runCode(L, std::string("return ") + line) == std::string()) + if (buffer.empty() && runCode(L, std::string("return ") + line.get()) == std::string()) { - ic_history_add(line); + ic_history_add(line.get()); continue; } if (!buffer.empty()) buffer += "\n"; - buffer += line; + buffer += line.get(); std::string error = runCode(L, buffer); @@ -400,7 +416,6 @@ static void runReplImpl(lua_State* L) ic_history_add(buffer.c_str()); buffer.clear(); - free((void*)line); } } @@ -504,7 +519,7 @@ static bool compileFile(const char* name, CompileFormat format) if (format == CompileFormat::Text) { - bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source); + bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Source | Luau::BytecodeBuilder::Dump_Locals); bcb.setDumpSource(*source); } @@ -549,7 +564,8 @@ static void displayHelp(const char* argv0) printf(" --coverage: collect code coverage while running the code and output results to coverage.out\n"); printf(" -h, --help: Display this usage message.\n"); printf(" -i, --interactive: Run an interactive REPL after executing the last script specified.\n"); - printf(" -O: use compiler optimization level (n=0-2).\n"); + printf(" -O: compile with optimization level n (default 1, n should be between 0 and 2).\n"); + printf(" -g: compile with debug level n (default 1, n should be between 0 and 2).\n"); printf(" --profile[=N]: profile the code using N Hz sampling (default 10000) and output results to profile.out\n"); printf(" --timetrace: record compiler time tracing information into trace.json\n"); } @@ -620,6 +636,16 @@ int replMain(int argc, char** argv) } globalOptions.optimizationLevel = level; } + else if (strncmp(argv[i], "-g", 2) == 0) + { + int level = atoi(argv[i] + 2); + if (level < 0 || level > 2) + { + fprintf(stderr, "Error: Debug level must be between 0 and 2 inclusive.\n"); + return 1; + } + globalOptions.debugLevel = level; + } else if (strcmp(argv[i], "--profile") == 0) { profile = 10000; // default to 10 KHz diff --git a/CLI/Repl.h b/CLI/Repl.h index 11a077ae8..cd54b7e08 100644 --- a/CLI/Repl.h +++ b/CLI/Repl.h @@ -3,10 +3,15 @@ #include "lua.h" +#include #include +using AddCompletionCallback = std::function; + // Note: These are internal functions which are being exposed in a header // so they can be included by unit tests. -int replMain(int argc, char** argv); void setupState(lua_State* L); std::string runCode(lua_State* L, const std::string& source); +void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback); + +int replMain(int argc, char** argv); diff --git a/CMakeLists.txt b/CMakeLists.txt index 881d3c3f9..c19d2b40b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -5,13 +5,20 @@ if(EXT_PLATFORM_STRING) endif() cmake_minimum_required(VERSION 3.0) -project(Luau LANGUAGES CXX C) option(LUAU_BUILD_CLI "Build CLI" ON) option(LUAU_BUILD_TESTS "Build tests" ON) option(LUAU_BUILD_WEB "Build Web module" OFF) option(LUAU_WERROR "Warnings as errors" OFF) +option(LUAU_STATIC_CRT "Link with the static CRT (/MT)" OFF) + +if(LUAU_STATIC_CRT) + cmake_minimum_required(VERSION 3.15) + cmake_policy(SET CMP0091 NEW) + set(CMAKE_MSVC_RUNTIME_LIBRARY "MultiThreaded$<$:Debug>") +endif() +project(Luau LANGUAGES CXX C) add_library(Luau.Ast STATIC) add_library(Luau.Compiler STATIC) add_library(Luau.Analysis STATIC) @@ -21,10 +28,12 @@ add_library(isocline STATIC) if(LUAU_BUILD_CLI) add_executable(Luau.Repl.CLI) add_executable(Luau.Analyze.CLI) + add_executable(Luau.Ast.CLI) # This also adds target `name` on Linux/macOS and `name.exe` on Windows set_target_properties(Luau.Repl.CLI PROPERTIES OUTPUT_NAME luau) set_target_properties(Luau.Analyze.CLI PROPERTIES OUTPUT_NAME luau-analyze) + set_target_properties(Luau.Ast.CLI PROPERTIES OUTPUT_NAME luau-ast) endif() if(LUAU_BUILD_TESTS) @@ -98,6 +107,7 @@ endif() if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) + target_compile_options(Luau.Ast.CLI PRIVATE ${LUAU_OPTIONS}) target_include_directories(Luau.Repl.CLI PRIVATE extern extern/isocline/include) @@ -111,6 +121,8 @@ if(LUAU_BUILD_CLI) endif() target_link_libraries(Luau.Analyze.CLI PRIVATE Luau.Analysis) + + target_link_libraries(Luau.Ast.CLI PRIVATE Luau.Ast Luau.Analysis) endif() if(LUAU_BUILD_TESTS) diff --git a/Compiler/src/BytecodeBuilder.cpp b/Compiler/src/BytecodeBuilder.cpp index e6d024546..09f06b686 100644 --- a/Compiler/src/BytecodeBuilder.cpp +++ b/Compiler/src/BytecodeBuilder.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauBytecodeV2Write, false) - namespace Luau { @@ -510,7 +508,7 @@ uint32_t BytecodeBuilder::getDebugPC() const void BytecodeBuilder::finalize() { LUAU_ASSERT(bytecode.empty()); - bytecode = char(FFlag::LuauBytecodeV2Write ? LBC_VERSION_FUTURE : LBC_VERSION); + bytecode = char(LBC_VERSION_FUTURE); writeStringTable(bytecode); @@ -611,9 +609,7 @@ void BytecodeBuilder::writeFunction(std::string& ss, uint32_t id) const writeVarInt(ss, child); // debug info - if (FFlag::LuauBytecodeV2Write) - writeVarInt(ss, func.debuglinedefined); - + writeVarInt(ss, func.debuglinedefined); writeVarInt(ss, func.debugname); bool hasLines = true; diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index e4253adc3..656a99265 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -15,7 +15,6 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauCompileTableIndexOpt, false) LUAU_FASTFLAG(LuauCompileSelectBuiltin2) namespace Luau @@ -1182,18 +1181,9 @@ struct Compiler const AstExprTable::Item& item = expr->items.data[i]; LUAU_ASSERT(item.key); // no list portion => all items have keys - if (FFlag::LuauCompileTableIndexOpt) - { - const Constant* ckey = constants.find(item.key); - - indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1)); - } - else - { - AstExprConstantNumber* ckey = item.key->as(); + const Constant* ckey = constants.find(item.key); - indexSize += (ckey && ckey->value == double(indexSize + 1)); - } + indexSize += (ckey && ckey->type == Constant::Type_Number && ckey->valueNumber == double(indexSize + 1)); } // we only perform the optimization if we don't have any other []-keys @@ -1295,43 +1285,10 @@ struct Compiler { RegScope rsi(this); - if (FFlag::LuauCompileTableIndexOpt) - { - LValue lv = compileLValueIndex(reg, key, rsi); - uint8_t rv = compileExprAuto(value, rsi); + LValue lv = compileLValueIndex(reg, key, rsi); + uint8_t rv = compileExprAuto(value, rsi); - compileAssign(lv, rv); - } - else - { - // Optimization: use SETTABLEKS/SETTABLEN for literal keys, this happens often as part of usual table construction syntax - if (AstExprConstantString* ckey = key->as()) - { - BytecodeBuilder::StringRef cname = sref(ckey->value); - int32_t cid = bytecode.addConstantString(cname); - if (cid < 0) - CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLEKS, rv, reg, uint8_t(BytecodeBuilder::getStringHash(cname))); - bytecode.emitAux(cid); - } - else if (AstExprConstantNumber* ckey = key->as(); - ckey && ckey->value >= 1 && ckey->value <= 256 && double(int(ckey->value)) == ckey->value) - { - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLEN, rv, reg, uint8_t(int(ckey->value) - 1)); - } - else - { - uint8_t rk = compileExprAuto(key, rsi); - uint8_t rv = compileExprAuto(value, rsi); - - bytecode.emitABC(LOP_SETTABLE, rv, reg, rk); - } - } + compileAssign(lv, rv); } // items without a key are set using SETLIST so that we can initialize large arrays quickly else @@ -1439,8 +1396,7 @@ struct Compiler uint8_t rt = compileExprAuto(expr->expr, rs); uint8_t i = uint8_t(int(cv->valueNumber) - 1); - if (FFlag::LuauCompileTableIndexOpt) - setDebugLine(expr->index); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEN, target, rt, i); } @@ -1453,8 +1409,7 @@ struct Compiler if (cid < 0) CompileError::raise(expr->location, "Exceeded constant limit; simplify the code to compile"); - if (FFlag::LuauCompileTableIndexOpt) - setDebugLine(expr->index); + setDebugLine(expr->index); bytecode.emitABC(LOP_GETTABLEKS, target, rt, uint8_t(BytecodeBuilder::getStringHash(iname))); bytecode.emitAux(cid); @@ -1853,8 +1808,7 @@ struct Compiler void compileLValueUse(const LValue& lv, uint8_t reg, bool set) { - if (FFlag::LuauCompileTableIndexOpt) - setDebugLine(lv.location); + setDebugLine(lv.location); switch (lv.kind) { diff --git a/Sources.cmake b/Sources.cmake index b36b6db56..773f6f351 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -193,6 +193,14 @@ if(TARGET Luau.Analyze.CLI) CLI/Analyze.cpp) endif() +if(TARGET Luau.Ast.CLI) + target_sources(Luau.Ast.CLI PRIVATE + CLI/Ast.cpp + CLI/FileUtils.h + CLI/FileUtils.cpp + ) +endif() + if(TARGET Luau.UnitTest) # Luau.UnitTest Sources target_sources(Luau.UnitTest PRIVATE diff --git a/VM/src/lbuiltins.cpp b/VM/src/lbuiltins.cpp index ecc14e87b..718d387d8 100644 --- a/VM/src/lbuiltins.cpp +++ b/VM/src/lbuiltins.cpp @@ -1098,7 +1098,7 @@ static int luauF_select(lua_State* L, StkId res, TValue* arg0, int nresults, Stk int i = int(nvalue(arg0)); // i >= 1 && i <= n - if (unsigned(i - 1) <= unsigned(n)) + if (unsigned(i - 1) < unsigned(n)) { setobj2s(L, res, L->base - n + (i - 1)); return 1; diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 906fb0d04..ce1965200 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -250,6 +250,8 @@ void luaC_validate(lua_State* L) if (FFlag::LuauGcPagedSweep) { + validategco(L, NULL, obj2gco(g->mainthread)); + luaM_visitgco(L, L, validategco); } else @@ -565,6 +567,8 @@ void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* if (FFlag::LuauGcPagedSweep) { + dumpgco(f, NULL, obj2gco(g->mainthread)); + luaM_visitgco(L, f, dumpgco); } else diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index de85cf595..19617b8ca 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -8,6 +8,76 @@ #include +/* + * Luau heap uses a size-segregated page structure, with individual pages and large allocations + * allocated using system heap (via frealloc callback). + * + * frealloc callback serves as a general, if slow, allocation callback that can allocate, free or + * resize allocations: + * + * void* frealloc(void* ud, void* ptr, size_t oldsize, size_t newsize); + * + * frealloc(ud, NULL, 0, x) creates a new block of size x + * frealloc(ud, p, x, 0) frees the block p (must return NULL) + * frealloc(ud, NULL, 0, 0) does nothing, equivalent to free(NULL) + * + * frealloc returns NULL if it cannot create or reallocate the area + * (any reallocation to an equal or smaller size cannot fail!) + * + * On top of this, Luau implements heap storage which is split into two types of allocations: + * + * - GCO, short for "garbage collected objects" + * - other objects (for example, arrays stored inside table objects) + * + * The heap layout for these two allocation types is a bit different. + * + * All GCO are allocated in pages, which is a block of memory of ~16K in size that has a page header + * (lua_Page). Each page contains 1..N blocks of the same size, where N is selected to fill the page + * completely. This amortizes the allocation cost and increases locality. Each GCO block starts with + * the GC header (GCheader) which contains the object type, mark bits and other GC metadata. If the + * GCO block is free (not used), then it must have the type set to TNIL; in this case the block can + * be part of the per-page free list, the link for that list is stored after the header (freegcolink). + * + * Importantly, the GCO block doesn't have any back references to the page it's allocated in, so it's + * impossible to free it in isolation - GCO blocks are freed by sweeping the pages they belong to, + * using luaM_freegco which must specify the page; this is called by page sweeper that traverses the + * entire page's worth of objects. For this reason it's also important that freed GCO blocks keep the + * GC header intact and accessible (with type = NIL) so that the sweeper can access it. + * + * Some GCOs are too large to fit in a 16K page without excessive fragmentation (the size threshold is + * currently 512 bytes); in this case, we allocate a dedicated small page with just a single block's worth + * storage space, but that requires allocating an extra page header. In effect large GCOs are a little bit + * less memory efficient, but this allows us to uniformly sweep small and large GCOs using page lists. + * + * All GCO pages are linked in a large intrusive linked list (global_State::allgcopages). Additionally, + * for each block size there's a page free list that contains pages that have at least one free block + * (global_State::freegcopages). This free list is used to make sure object allocation is O(1). + * + * Compared to GCOs, regular allocations have two important differences: they can be freed in isolation, + * and they don't start with a GC header. Because of this, each allocation is prefixed with block metadata, + * which contains the pointer to the page for allocated blocks, and the pointer to the next free block + * inside the page for freed blocks. + * For regular allocations that are too large to fit in a page (using the same threshold of 512 bytes), + * we don't allocate a separate page, instead simply using frealloc to allocate a vanilla block of memory. + * + * Just like GCO pages, we store a page free list (global_State::freepages) that allows O(1) allocation; + * there is no global list for non-GCO pages since we never need to traverse them directly. + * + * In both cases, we pick the page by computing the size class from the block size which rounds the block + * size up to reduce the chance that we'll allocate pages that have very few allocated blocks. The size + * class strategy is determined by SizeClassConfig constructor. + * + * Note that when the last block in a page is freed, we immediately free the page with frealloc - the + * memory manager doesn't currently attempt to keep unused memory around. This can result in excessive + * allocation traffic and can be mitigated by adding a page cache in the future. + * + * For both GCO and non-GCO pages, the per-page block allocation combines bump pointer style allocation + * (lua_Page::freeNext) and per-page free list (lua_Page::freeList). We use the bump allocator to allocate + * the contents of the page, and the free list for further reuse; this allows shorter page setup times + * which results in less variance between allocation cost, as well as tighter sweep bounds for newly + * allocated pages. + */ + LUAU_FASTFLAG(LuauGcPagedSweep) #ifndef __has_feature @@ -56,6 +126,7 @@ static_assert(offsetof(GCObject, ts) == 0, "TString data must be located at the const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; const size_t kPageSize = 16 * 1024 - 24; // slightly under 16KB since that results in less fragmentation due to heap metadata + const size_t kBlockHeader = sizeof(double) > sizeof(void*) ? sizeof(double) : sizeof(void*); // suitable for aligning double & void* on all platforms const size_t kGCOLinkOffset = (sizeof(GCheader) + sizeof(void*) - 1) & ~(sizeof(void*) - 1); // GCO pages contain freelist links after the GC header @@ -107,24 +178,6 @@ const SizeClassConfig kSizeClassConfig; #define metadata(block) (*(void**)(block)) #define freegcolink(block) (*(void**)((char*)block + kGCOLinkOffset)) -/* -** About the realloc function: -** void * frealloc (void *ud, void *ptr, size_t osize, size_t nsize); -** (`osize' is the old size, `nsize' is the new size) -** -** Lua ensures that (ptr == NULL) iff (osize == 0). -** -** * frealloc(ud, NULL, 0, x) creates a new block of size `x' -** -** * frealloc(ud, p, x, 0) frees the block `p' -** (in this specific case, frealloc must return NULL). -** particularly, frealloc(ud, NULL, 0, 0) does nothing -** (which is equivalent to free(NULL) in ANSI C) -** -** frealloc returns NULL if it cannot create or reallocate the area -** (any reallocation to an equal or smaller size cannot fail!) -*/ - struct lua_Page { // list of pages with free blocks @@ -135,13 +188,12 @@ struct lua_Page lua_Page* gcolistprev; lua_Page* gcolistnext; - int busyBlocks; - int blockSize; - - void* freeList; - int freeNext; + int pageSize; // page size in bytes, including page header + int blockSize; // block size in bytes, including block header (for non-GCO) - int pageSize; + void* freeList; // next free block in this page; linked with metadata()/freegcolink() + int freeNext; // next free block offset in this page, in bytes; when negative, freeList is used instead + int busyBlocks; // number of blocks allocated out of this page union { @@ -177,7 +229,7 @@ static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) page->gcolistprev = NULL; page->gcolistnext = NULL; - page->busyBlocks = 0; + page->pageSize = kPageSize; page->blockSize = blockSize; // note: we start with the last block in the page and move downward @@ -185,6 +237,7 @@ static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order page->freeList = NULL; page->freeNext = (blockCount - 1) * blockSize; + page->busyBlocks = 0; // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) LUAU_ASSERT(!g->freepages[sizeClass]); @@ -214,7 +267,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int page->gcolistprev = NULL; page->gcolistnext = NULL; - page->busyBlocks = 0; + page->pageSize = pageSize; page->blockSize = blockSize; // note: we start with the last block in the page and move downward @@ -222,8 +275,7 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order page->freeList = NULL; page->freeNext = (blockCount - 1) * blockSize; - - page->pageSize = pageSize; + page->busyBlocks = 0; if (gcopageset) { @@ -406,8 +458,7 @@ static void* newgcoblock(lua_State* L, int sizeClass) page->next = NULL; } - // the user data is right after the metadata - return (char*)block; + return block; } static void freeblock(lua_State* L, int sizeClass, void* block) @@ -421,6 +472,7 @@ static void freeblock(lua_State* L, int sizeClass, void* block) lua_Page* page = (lua_Page*)metadata(block); LUAU_ASSERT(page && page->busyBlocks > 0); LUAU_ASSERT(size_t(page->blockSize) == kSizeClassConfig.sizeOfClass[sizeClass] + kBlockHeader); + LUAU_ASSERT(block >= page->data && block < (char*)page + page->pageSize); // if the page wasn't in the page free list, it should be now since it got a block! if (!page->freeList && page->freeNext < 0) @@ -455,6 +507,9 @@ static void freeblock(lua_State* L, int sizeClass, void* block) static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) { LUAU_ASSERT(FFlag::LuauGcPagedSweep); + LUAU_ASSERT(page && page->busyBlocks > 0); + LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); + LUAU_ASSERT(block >= page->data && block < (char*)page + page->pageSize); global_State* g = L->global; @@ -575,6 +630,8 @@ void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, else { LUAU_ASSERT(page->busyBlocks == 1); + LUAU_ASSERT(size_t(page->blockSize) == osize); + LUAU_ASSERT((void*)block == page->data); freepage(L, &g->allgcopages, page); } @@ -626,8 +683,12 @@ void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlo int blockCount = (page->pageSize - offsetof(lua_Page, data)) / page->blockSize; - *start = page->data + page->freeNext + page->blockSize; - *end = page->data + blockCount * page->blockSize; + LUAU_ASSERT(page->freeNext >= -page->blockSize && page->freeNext <= (blockCount - 1) * page->blockSize); + + char* data = page->data; // silences ubsan when indexing page->data + + *start = data + page->freeNext + page->blockSize; + *end = data + blockCount * page->blockSize; *busyBlocks = page->busyBlocks; *blockSize = page->blockSize; } @@ -675,7 +736,7 @@ void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, l for (lua_Page* curr = g->allgcopages; curr;) { - lua_Page* next = curr->gcolistnext; // page blockvisit might destroy the page + lua_Page* next = curr->gcolistnext; // block visit might destroy the page luaM_visitpage(curr, context, visitor); diff --git a/VM/src/lobject.cpp b/VM/src/lobject.cpp index 370c7b283..d5bd76a8d 100644 --- a/VM/src/lobject.cpp +++ b/VM/src/lobject.cpp @@ -131,7 +131,7 @@ void luaO_chunkid(char* out, const char* source, size_t bufflen) { size_t l; source++; /* skip the `@' */ - bufflen -= sizeof(" '...' "); + bufflen -= sizeof("..."); l = strlen(source); strcpy(out, ""); if (l > bufflen) @@ -144,7 +144,7 @@ void luaO_chunkid(char* out, const char* source, size_t bufflen) else { /* out = [string "string"] */ size_t len = strcspn(source, "\n\r"); /* stop at first newline */ - bufflen -= sizeof(" [string \"...\"] "); + bufflen -= sizeof("[string \"...\"]"); if (len > bufflen) len = bufflen; strcpy(out, "[string \""); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index cba3670ad..c3b662a2c 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -609,7 +609,8 @@ static void luau_execute(lua_State* L) if (unsigned(ic) < LUA_VECTOR_SIZE && name[1] == '\0') { - setnvalue(ra, rb->value.v[ic]); + const float* v = rb->value.v; // silences ubsan when indexing v[] + setnvalue(ra, v[ic]); VM_NEXT(); } diff --git a/extern/isocline/src/isocline.c b/extern/isocline/src/isocline.c index 132780628..8b6055cf9 100644 --- a/extern/isocline/src/isocline.c +++ b/extern/isocline/src/isocline.c @@ -13,7 +13,12 @@ // $ gcc -c src/isocline.c //------------------------------------------------------------- #if !defined(IC_SEPARATE_OBJS) -# define _CRT_SECURE_NO_WARNINGS // for msvc +# ifndef _CRT_NONSTDC_NO_WARNINGS +# define _CRT_NONSTDC_NO_WARNINGS // for msvc +# endif +# ifndef _CRT_SECURE_NO_WARNINGS +# define _CRT_SECURE_NO_WARNINGS // for msvc +# endif # define _XOPEN_SOURCE 700 // for wcwidth # define _DEFAULT_SOURCE // ensure usleep stays visible with _XOPEN_SOURCE >= 700 # include "attr.c" diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index d8af94dba..cd7a21d80 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -605,8 +605,6 @@ RETURN R0 1 TEST_CASE("TableLiteralsIndexConstant") { - ScopedFastFlag sff("LuauCompileTableIndexOpt", true); - // validate that we use SETTTABLEKS for constant variable keys CHECK_EQ("\n" + compileFunction0(R"( local a, b = "key", "value" @@ -2483,8 +2481,6 @@ return TEST_CASE("DebugLineInfoAssignment") { - ScopedFastFlag sff("LuauCompileTableIndexOpt", true); - Luau::BytecodeBuilder bcb; bcb.setDumpFlags(Luau::BytecodeBuilder::Dump_Code | Luau::BytecodeBuilder::Dump_Lines); Luau::compileOrThrow(bcb, R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index e580949f5..8b58d2ce8 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -492,8 +492,6 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag sffw("LuauBytecodeV2Write", true); - runConformance("debug.lua"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index d1cc49b2d..577415fca 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1392,19 +1392,31 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") {"DataCost", {typeChecker.numberType, /* deprecated= */ true}}, {"Wait", {typeChecker.anyType, /* deprecated= */ true}}, }; + + TypeId colorType = typeChecker.globalTypes.addType(TableTypeVar{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); + + getMutable(colorType)->props = { + {"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"} } + }; + + addGlobalBinding(typeChecker, "Color3", Binding{colorType, {}}); + freeze(typeChecker.globalTypes); LintResult result = lintTyped(R"( return function (i: Instance) i:Wait(1.0) print(i.Name) + print(Color3.toHSV()) + print(Color3.doesntexist, i.doesntexist) -- type error, but this verifies we correctly handle non-existent members return i.DataCost end )"); - REQUIRE_EQ(result.warnings.size(), 2); + REQUIRE_EQ(result.warnings.size(), 3); CHECK_EQ(result.warnings[0].text, "Member 'Instance.Wait' is deprecated"); - CHECK_EQ(result.warnings[1].text, "Member 'Instance.DataCost' is deprecated"); + CHECK_EQ(result.warnings[1].text, "Member 'toHSV' is deprecated, use 'Color3:ToHSV' instead"); + CHECK_EQ(result.warnings[2].text, "Member 'Instance.DataCost' is deprecated"); } TEST_CASE_FIXTURE(Fixture, "TableOperations") @@ -1475,9 +1487,11 @@ _ = (true and true) or true _ = (true and false) and (42 and false) _ = true and true or false -- no warning since this is is a common pattern used as a ternary replacement + +_ = if true then 1 elseif true then 2 else 3 )"); - REQUIRE_EQ(result.warnings.size(), 7); + REQUIRE_EQ(result.warnings.size(), 8); CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); CHECK_EQ(result.warnings[0].location.begin.line + 1, 4); CHECK_EQ(result.warnings[1].text, "Condition has already been checked on column 5"); @@ -1487,6 +1501,7 @@ _ = true and true or false -- no warning since this is is a common pattern used CHECK_EQ(result.warnings[5].text, "Condition has already been checked on column 6"); CHECK_EQ(result.warnings[6].text, "Condition has already been checked on column 15"); CHECK_EQ(result.warnings[6].location.begin.line + 1, 19); + CHECK_EQ(result.warnings[7].text, "Condition has already been checked on column 8"); } TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsExpr") @@ -1528,4 +1543,19 @@ return foo, moo, a1, a2 CHECK_EQ(result.warnings[3].text, "Function parameter 'self' already defined implicitly"); } +TEST_CASE_FIXTURE(Fixture, "MisleadingAndOr") +{ + LintResult result = lint(R"( +_ = math.random() < 0.5 and true or 42 +_ = math.random() < 0.5 and false or 42 -- misleading +_ = math.random() < 0.5 and nil or 42 -- misleading +_ = math.random() < 0.5 and 0 or 42 +_ = (math.random() < 0.5 and false) or 42 -- currently ignored +)"); + + REQUIRE_EQ(result.warnings.size(), 2); + CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; consider using if-then-else expression instead"); + CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; consider using if-then-else expression instead"); +} + TEST_SUITE_END(); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index f660bcd3f..1f9c97397 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -8,9 +8,22 @@ #include #include +#include #include #include +struct Completion +{ + std::string completion; + std::string display; + + bool operator<(Completion const& other) const + { + return std::tie(completion, display) < std::tie(other.completion, other.display); + } +}; + +using CompletionSet = std::set; class ReplFixture { @@ -34,6 +47,27 @@ class ReplFixture lua_pop(L, 1); return result; } + + CompletionSet getCompletionSet(const char* inputPrefix) + { + CompletionSet result; + int top = lua_gettop(L); + getCompletions(L, inputPrefix, [&result](const std::string& completion, const std::string& display) { + result.insert(Completion{completion, display}); + }); + // Ensure that generating completions doesn't change the position of luau's stack top. + CHECK(top == lua_gettop(L)); + + return result; + } + + bool checkCompletion(const CompletionSet& completions, const std::string& prefix, const std::string& expected) + { + std::string expectedDisplay(expected.substr(0, expected.find_first_of('('))); + Completion expectedCompletion{prefix + expected, expectedDisplay}; + return completions.count(expectedCompletion) == 1; + } + lua_State* L; private: @@ -115,3 +149,61 @@ TEST_CASE_FIXTURE(ReplFixture, "MultipleArguments") } TEST_SUITE_END(); + +TEST_SUITE_BEGIN("ReplCodeCompletion"); + +TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") +{ + runCode(L, R"( + myvariable1 = 5 + myvariable2 = 5 +)"); + CompletionSet completions = getCompletionSet("myvar"); + + std::string prefix = ""; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "myvariable1")); + CHECK(checkCompletion(completions, prefix, "myvariable2")); +} + +TEST_CASE_FIXTURE(ReplFixture, "CompleteTableKeys") +{ + runCode(L, R"( + t = { color = "red", size = 1, shape = "circle" } +)"); + { + CompletionSet completions = getCompletionSet("t."); + + std::string prefix = "t."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "color")); + CHECK(checkCompletion(completions, prefix, "size")); + CHECK(checkCompletion(completions, prefix, "shape")); + } + + { + CompletionSet completions = getCompletionSet("t.s"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "size")); + CHECK(checkCompletion(completions, prefix, "shape")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "StringMethods") +{ + runCode(L, R"( + s = "" +)"); + { + CompletionSet completions = getCompletionSet("s:l"); + + std::string prefix = "s:"; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "len(")); + CHECK(checkCompletion(completions, prefix, "lower(")); + } +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 76ab23b3a..a87292683 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -595,4 +595,65 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ LUAU_REQUIRE_NO_ERRORS(result); } +/* + * The two-pass alias definition system starts by ascribing a free TypeVar to each alias. It then + * circles back to fill in the actual type later on. + * + * If this free type is unified with something degenerate like `any`, we need to take extra care + * to ensure that the alias actually binds to the type that the user expected. + */ +TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") +{ + ScopedFastFlag sff[] = { + {"LuauTwoPassAliasDefinitionFix", true} + }; + + CheckResult result = check(R"( + local function x() + local y: FutureType = {}::any + return 1 + end + type FutureType = { foo: typeof(x()) } + local d: FutureType = { smth = true } -- missing error, 'd' is resolved to 'any' + )"); + + CHECK_EQ("{| foo: number |}", toString(requireType("d"), {true})); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any_2") +{ + ScopedFastFlag sff[] = { + {"LuauTwoPassAliasDefinitionFix", true}, + + // We also force these two flags because this surfaced an unfortunate interaction. + {"LuauErrorRecoveryType", true}, + {"LuauQuantifyInPlace2", true}, + }; + + CheckResult result = check(R"( + local B = {} + B.bar = 4 + + function B:smth1() + local self: FutureIntersection = self + self.foo = 4 + return 4 + end + + function B:smth2() + local self: FutureIntersection = self + self.bar = 5 -- error, even though we should have B part with bar + end + + type A = { foo: typeof(B.smth1({foo=3})) } -- trick toposort into sorting functions before types + type B = typeof(B) + + type FutureIntersection = A & B + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 6730bedb0..df06884d8 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauFixTonumberReturnType) - using namespace Luau; LUAU_FASTFLAG(LuauUseCommittingTxnLog) @@ -850,11 +848,8 @@ TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type") local b: number = tonumber('asdf') )"); - if (FFlag::LuauFixTonumberReturnType) - { - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); - } + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number?' could not be converted into 'number'", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "tonumber_returns_optional_number_type2") @@ -893,7 +888,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, }; CheckResult result = check(R"( @@ -910,7 +905,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types_even_from_type_pack_tail_ { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, }; CheckResult result = check(R"( @@ -927,7 +922,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_fir { ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index e5eb0dca0..2bcd840c8 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -262,7 +262,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_equals_another_lvalue_with_no_overlap") // Just needs to fully support equality refinement. Which is annoying without type states. TEST_CASE_FIXTURE(Fixture, "discriminate_from_x_not_equal_to_nil") { - ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; CheckResult result = check(R"( type T = {x: string, y: number} | {x: nil, y: nil} @@ -616,4 +616,76 @@ local a: Self
CHECK_EQ(toString(requireType("a")), "Table
"); } +TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") +{ + ScopedFastFlag sff[]{ + {"LuauQuantifyInPlace2", true}, + {"LuauReturnAnyInsteadOfICE", true}, + }; + + // In-place quantification causes these types to have the wrong types but only because of nasty interaction with prototyping. + // The type of f is initially () -> free1... + // Then the prototype iterator advances, and checks the function expression assigned to g, which has the type () -> free2... + // In the body it calls f and returns what f() returns. This binds free2... with free1..., causing f and g to have same types. + // We then quantify g, leaving it with the final type () -> a... + // Because free1... and free2... were bound, in combination with in-place quantification, f's return type was also turned into a... + // Then the check iterator catches up, and checks the body of f, and attempts to quantify it too. + // Alas, one of the requirements for quantification is that a type must contain free types. () -> a... has no free types. + // Thus the quantification for f was no-op, which explains why f does not have any type parameters. + // Calling f() will attempt to instantiate the function type, which turns generics in type binders into to free types. + // However, instantiations only converts generics contained within the type binders of a function, so instantiation was also no-op. + // Which means that calling f() simply returned a... rather than an instantiation of it. And since the call site was not in tail position, + // picking first element in a... triggers an ICE because calls returning generic packs are unexpected. + CheckResult result = check(R"( + local function f() end + + local g = function() return f() end + + local x = (f()) -- should error: no return values to assign from the call to f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // f and g should have the type () -> () + CHECK_EQ("() -> (a...)", toString(requireType("f"))); + CHECK_EQ("() -> (a...)", toString(requireType("g"))); + CHECK_EQ("any", toString(requireType("x"))); // any is returned instead of ICE for now +} + +TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") +{ + CheckResult result = check(R"( + local function id(x) return x end + local n2n: (number) -> number = id + local s2s: (string) -> string = id + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") +{ + ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; + + CheckResult result = check(R"( + local function f() return end + local g = function() return f() end + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + +TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") +{ + ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; + + CheckResult result = check(R"( + --!strict + local function f(...) return ... end + local g = function(...) return f(...) end + )"); + + LUAU_REQUIRE_ERRORS(result); // Should not have any errors. +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 3a610c3a1..48e6be6a4 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -6,7 +6,7 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauDiscriminableUnions) +LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) LUAU_FASTFLAG(LuauQuantifyInPlace2) @@ -262,7 +262,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_only_look_up_types_from_global_scope") end )"); - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { LUAU_REQUIRE_NO_ERRORS(result); @@ -435,7 +435,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_equal_to_a_term") TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauSingletonTypes", true}, }; @@ -485,7 +485,7 @@ TEST_CASE_FIXTURE(Fixture, "lvalue_is_not_nil") TEST_CASE_FIXTURE(Fixture, "free_type_is_equal_to_an_lvalue") { - ScopedFastFlag sff{"LuauDiscriminableUnions", true}; + ScopedFastFlag sff{"LuauDiscriminableUnions2", true}; ScopedFastFlag sff2{"LuauWeakEqConstraint", true}; CheckResult result = check(R"( @@ -589,7 +589,7 @@ TEST_CASE_FIXTURE(Fixture, "type_narrow_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) { LUAU_REQUIRE_NO_ERRORS(result); } @@ -1002,7 +1002,7 @@ TEST_CASE_FIXTURE(Fixture, "apply_refinements_on_astexprindexexpr_whose_subscrip TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, }; @@ -1028,7 +1028,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") TEST_CASE_FIXTURE(Fixture, "discriminate_tag") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, }; @@ -1069,7 +1069,7 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1094,7 +1094,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1118,7 +1118,7 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { ScopedFastFlag sff[] = { - {"LuauDiscriminableUnions", true}, + {"LuauDiscriminableUnions2", true}, {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, }; @@ -1157,7 +1157,7 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") end )"); - if (FFlag::LuauDiscriminableUnions) + if (FFlag::LuauDiscriminableUnions2) LUAU_REQUIRE_NO_ERRORS(result); else { diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index ead3d762c..531a382f5 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5164,4 +5164,151 @@ function x:Destroy(): () end LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +export type Type = { x: { a: number } } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = { x = { a = 2 } } +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +local y = setmetatable({}, {}) +export type Type = { x: typeof(y) } +return { x = y } + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = types +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + {"LuauLengthOnCompositeType", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauRefactorTypeVarQuestions", true}, + {"LuauSingletonTypes", true}, + {"LuauLengthOnCompositeType", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); +} + +/* + * When we add new properties to an unsealed table, we should do a level check and promote the property type to be at + * the level of the table. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table") +{ + CheckResult result = check(R"( + --!strict + local T = {} + + local function f(prop) + T[1] = { + prop = prop, + } + end + + local function g() + local l = T[1].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 0aeca0965..8c7fb79ab 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -273,4 +273,21 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") state.tryUnify(&metatable, &variant); } +TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") +{ + ScopedFastFlag sffs[] = { + {"LuauUseCommittingTxnLog", true}, + {"LuauFollowWithCommittingTxnLogInAnyUnification", true}, + }; + + TypePackVar free{FreeTypePack{TypeLevel{}}}; + TypePackVar target{TypePack{}}; + + TypeVar func{FunctionTypeVar{&free, &free}}; + + state.tryUnify(&free, &target); + // Shouldn't assert or error. + state.tryUnify(&func, typeChecker.anyType); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index de091632d..78d900770 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -118,9 +118,7 @@ assert((function() return #_G end)() == 0) assert((function() return #{1,2} end)() == 2) assert((function() return #'g' end)() == 1) -local ud = newproxy(true) -getmetatable(ud).__len = function() return 42 end -assert((function() return #ud end)() == 42) +assert((function() local ud = newproxy(true) getmetatable(ud).__len = function() return 42 end return #ud end)() == 42) assert((function() local a = 1 a = -a return a end)() == -1) @@ -325,6 +323,10 @@ assert((function() local t = {6, 9, 7} t[4.5] = 10 return t[4.5] end)() == 10) assert((function() local t = {6, 9, 7} t['a'] = 11 return t['a'] end)() == 11) assert((function() local t = {6, 9, 7} setmetatable(t, { __newindex = function(t,i,v) rawset(t, i * 10, v) end }) t[1] = 17 t[5] = 1 return concat(t[1],t[5],t[50]) end)() == "17,nil,1") +-- userdata access +assert((function() local ud = newproxy(true) getmetatable(ud).__index = function(ud,i) return i * 10 end return ud[2] end)() == 20) +assert((function() local ud = newproxy(true) getmetatable(ud).__index = function() return function(self, i) return i * 10 end end return ud:meow(2) end)() == 20) + -- and/or -- rhs is a constant assert((function() local a = 1 a = a and 2 return a end)() == 2) @@ -462,7 +464,7 @@ assert((function() a = {} b = {} mt = { __eq = function(l, r) return #l == #r en -- metatable ops local function vec3t(x, y, z) - return setmetatable({ x=x, y=y, z=z}, { + return setmetatable({x=x, y=y, z=z}, { __add = function(l, r) return vec3t(l.x + r.x, l.y + r.y, l.z + r.z) end, __sub = function(l, r) return vec3t(l.x - r.x, l.y - r.y, l.z - r.z) end, __mul = function(l, r) return type(r) == "number" and vec3t(l.x * r, l.y * r, l.z * r) or vec3t(l.x * r.x, l.y * r.y, l.z * r.z) end, diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 8c96ab335..0e4100005 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -37,6 +37,7 @@ coroutine.resume(co2, 0 / 0, 42) assert(debug.traceback(co2) == "debug.lua:31 function halp\n") assert(debug.info(co2, 0, "l") == 31) +assert(debug.info(co2, 0, "f") == halp) -- info errors function qux(...) diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index d5ff215b4..751188bed 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -260,8 +260,7 @@ local a,b = loadstring(s) assert(not a) --assert(string.find(b, "line 2")) --- Test for CLI-28786 --- The xpcall is intentially going to cause an exception +-- The xpcall is intentionally going to cause an exception -- followed by a forced exception in the error handler. -- If the secondary handler isn't trapped, it will cause -- the unit test to fail. If the xpcall captures the @@ -281,6 +280,19 @@ coroutine.wrap(function() assert(not pcall(debug.getinfo, coroutine.running(), 0, ">")) end)() +-- loadstring chunk truncation +local a,b = loadstring("nope", "@short") +assert(not a and b:match('[^ ]+') == "short:1:") + +local a,b = loadstring("nope", "@" .. string.rep("thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities", 10)) +assert(not a and b:match('[^ ]+') == "...wontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities:1:") + +local a,b = loadstring("nope", "=short") +assert(not a and b:match('[^ ]+') == "short:1:") + +local a,b = loadstring("nope", "=" .. string.rep("thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilities", 10)) +assert(not a and b:match('[^ ]+') == "thisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbufferprovidedtovariousdebugfacilitiesthisisaverylongstringitssolongthatitwontfitintotheinternalbuffe:1:") + -- arith errors function ecall(fn, ...) local ok, err = pcall(fn, ...) diff --git a/tests/conformance/gc.lua b/tests/conformance/gc.lua index 409cd2247..5804ea7f7 100644 --- a/tests/conformance/gc.lua +++ b/tests/conformance/gc.lua @@ -180,6 +180,11 @@ x,y,z=nil collectgarbage() assert(next(a) == string.rep('$', 11)) +-- shrinking tables reduce their capacity; confirming the shrinking is difficult but we can at least test the surface level behavior +a = {}; setmetatable(a, {__mode = 'ks'}) +for i=1,lim do a[{}] = i end +collectgarbage() +assert(next(a) == nil) -- testing userdata collectgarbage("stop") -- stop collection @@ -315,8 +320,6 @@ do end collectgarbage() - end - return('OK') diff --git a/tests/conformance/math.lua b/tests/conformance/math.lua index bfea0e1f1..79ea0fb69 100644 --- a/tests/conformance/math.lua +++ b/tests/conformance/math.lua @@ -289,6 +289,7 @@ assert(math.sqrt("4") == 2) assert(math.tanh("0") == 0) assert(math.tan("0") == 0) assert(math.clamp("0", 2, 3) == 2) +assert(math.clamp("4", 2, 3) == 3) assert(math.sign("2") == 1) assert(math.sign("-2") == -1) assert(math.sign("0") == 0) diff --git a/tests/conformance/vararg.lua b/tests/conformance/vararg.lua index d05f95776..178c56b82 100644 --- a/tests/conformance/vararg.lua +++ b/tests/conformance/vararg.lua @@ -139,6 +139,12 @@ assert(selectmany(1, 10, 20, 30) == "10,20,30") assert(selectone(2, 10, 20, 30) == 20) assert(selectmany(2, 10, 20, 30) == "20,30") +assert(selectone(3, 10, 20, 30) == 30) +assert(selectmany(3, 10, 20, 30) == "30") + +assert(selectone(4, 10, 20, 30) == nil) +assert(selectmany(4, 10, 20, 30) == "") + assert(selectone(-2, 10, 20, 30) == 20) assert(selectmany(-2, 10, 20, 30) == "20,30") diff --git a/tests/conformance/vector.lua b/tests/conformance/vector.lua index 7d18bda33..22d6adfc1 100644 --- a/tests/conformance/vector.lua +++ b/tests/conformance/vector.lua @@ -87,9 +87,18 @@ assert(pcall(function() local t = {} rawset(t, vector(0/0, 2, 3), 1) end) == fal -- make sure we cover both builtin and C impl assert(vector(1, 2, 4) == vector("1", "2", "4")) +-- validate component access (both cases) +assert(vector(1, 2, 3).x == 1) +assert(vector(1, 2, 3).X == 1) +assert(vector(1, 2, 3).y == 2) +assert(vector(1, 2, 3).Y == 2) +assert(vector(1, 2, 3).z == 3) +assert(vector(1, 2, 3).Z == 3) + -- additional checks for 4-component vectors if vector_size == 4 then assert(vector(1, 2, 3, 4).w == 4) + assert(vector(1, 2, 3, 4).W == 4) end return 'OK' diff --git a/tools/heapgraph.py b/tools/heapgraph.py index 106db5491..d4d29af16 100644 --- a/tools/heapgraph.py +++ b/tools/heapgraph.py @@ -7,10 +7,20 @@ # The result of analysis is a .svg file which can be viewed in a browser # To generate these dumps, use luaC_dump, ideally preceded by luaC_fullgc +import argparse import json import sys import svg +argumentParser = argparse.ArgumentParser(description='Luau heap snapshot analyzer') + +argumentParser.add_argument('--split', dest = 'split', type = str, default = 'none', help = 'Perform additional root split using memory categories', choices = ['none', 'custom', 'all']) + +argumentParser.add_argument('snapshot') +argumentParser.add_argument('snapshotnew', nargs='?') + +arguments = argumentParser.parse_args() + class Node(svg.Node): def __init__(self): svg.Node.__init__(self) @@ -30,14 +40,14 @@ def details(self, root): return "{} ({:,} bytes, {:.1%}); self: {:,} bytes in {:,} objects".format(self.name, self.width, self.width / root.width, self.size, self.count) # load files -if len(sys.argv) == 2: +if arguments.snapshotnew == None: dumpold = None - with open(sys.argv[1]) as f: + with open(arguments.snapshot) as f: dump = json.load(f) else: - with open(sys.argv[1]) as f: + with open(arguments.snapshot) as f: dumpold = json.load(f) - with open(sys.argv[2]) as f: + with open(arguments.snapshotnew) as f: dump = json.load(f) # reachability analysis: how much of the heap is reachable from roots? @@ -111,12 +121,15 @@ def details(self, root): if "object" in obj: queue.append((obj["object"], node)) -def annotateContainedCategories(node): +def annotateContainedCategories(node, start): for obj in node.objects: + if obj["cat"] < start: + obj["cat"] = 0 + node.categories.add(obj["cat"]) for child in node.children.values(): - annotateContainedCategories(child) + annotateContainedCategories(child, start) for cat in child.categories: node.categories.add(cat) @@ -172,9 +185,11 @@ def splitIntoCategories(root): return result -# temporarily disabled because it makes FG harder to read, maybe this should be a separate command line option? -if dump["stats"].get("categories") and False: - annotateContainedCategories(root) +if dump["stats"].get("categories") and arguments.split != 'none': + if arguments.split == 'custom': + annotateContainedCategories(root, 128) + else: + annotateContainedCategories(root, 0) root = splitIntoCategories(root) diff --git a/tools/svg.py b/tools/svg.py index 99853fb6e..21200eeb6 100644 --- a/tools/svg.py +++ b/tools/svg.py @@ -452,7 +452,7 @@ def display(root, title, colors, flip = False): .replace("$gradient-start", gradient_start) .replace("$gradient-end", gradient_end) .replace("$height", str(svgheight)) - .replace("$status", str(svgheight - 16 + 3)) + .replace("$status", str((svgheight - 16 + 3 if flip else 3 * 16 - 3))) .replace("$flip", str(int(flip))) ) From 49304095161c4532a236b8141faf1e0c7dda498f Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 17 Feb 2022 16:41:20 -0800 Subject: [PATCH 26/32] Sync to upstream/release/515 --- Analysis/include/Luau/Documentation.h | 3 + Analysis/include/Luau/Frontend.h | 3 +- Analysis/include/Luau/Linter.h | 7 +- Analysis/include/Luau/Module.h | 4 +- Analysis/include/Luau/Quantify.h | 5 +- Analysis/include/Luau/Substitution.h | 20 +- Analysis/include/Luau/TypeInfer.h | 45 +++- Analysis/src/Autocomplete.cpp | 45 ++-- Analysis/src/Config.cpp | 2 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 2 +- Analysis/src/Frontend.cpp | 35 ++- Analysis/src/JsonEncoder.cpp | 36 ++- Analysis/src/Linter.cpp | 126 ++++++++- Analysis/src/Quantify.cpp | 10 +- Analysis/src/Substitution.cpp | 23 -- Analysis/src/TopoSortStatements.cpp | 3 +- Analysis/src/Transpiler.cpp | 8 +- Analysis/src/TypeAttach.cpp | 4 +- Analysis/src/TypeInfer.cpp | 142 +++-------- Analysis/src/TypeUtils.cpp | 12 + Analysis/src/TypeVar.cpp | 6 +- Analysis/src/TypedAllocator.cpp | 3 + Analysis/src/Unifier.cpp | 16 +- Ast/include/Luau/Ast.h | 17 +- Ast/include/Luau/Lexer.h | 5 + Ast/include/Luau/ParseResult.h | 69 +++++ Ast/include/Luau/Parser.h | 56 +--- Ast/src/Ast.cpp | 29 +-- Ast/src/Lexer.cpp | 12 +- Ast/src/Parser.cpp | 161 +++++++----- Ast/src/TimeTrace.cpp | 6 + CLI/FileUtils.cpp | 7 +- CLI/Repl.cpp | 167 ++++++++---- CMakeLists.txt | 14 + Sources.cmake | 1 + VM/src/lapi.cpp | 25 +- VM/src/ldebug.cpp | 2 +- VM/src/ldo.cpp | 13 +- VM/src/ldo.h | 2 +- VM/src/lgc.cpp | 4 +- VM/src/lgc.h | 2 +- VM/src/lgcdebug.cpp | 10 +- VM/src/lperf.cpp | 6 + VM/src/lstate.cpp | 35 ++- VM/src/lstate.h | 8 +- VM/src/lvmexecute.cpp | 2 +- VM/src/lvmload.cpp | 6 +- VM/src/lvmutils.cpp | 4 +- fuzz/linter.cpp | 2 +- fuzz/proto.cpp | 2 +- tests/Autocomplete.test.cpp | 22 +- tests/Compiler.test.cpp | 14 +- tests/Conformance.test.cpp | 8 + tests/Fixture.cpp | 4 +- tests/Fixture.h | 1 - tests/Frontend.test.cpp | 3 - tests/JsonEncoder.test.cpp | 2 +- tests/Linter.test.cpp | 48 +++- tests/NonstrictMode.test.cpp | 1 - tests/Parser.test.cpp | 51 ++-- tests/Repl.test.cpp | 209 ++++++++++++++- tests/RequireTracer.test.cpp | 2 +- tests/TypeInfer.aliases.test.cpp | 6 +- tests/TypeInfer.annotations.test.cpp | 1 - tests/TypeInfer.builtins.test.cpp | 1 - tests/TypeInfer.classes.test.cpp | 1 - tests/TypeInfer.definitions.test.cpp | 1 - tests/TypeInfer.generics.test.cpp | 1 - tests/TypeInfer.intersectionTypes.test.cpp | 1 - tests/TypeInfer.provisional.test.cpp | 1 - tests/TypeInfer.singletons.test.cpp | 2 - tests/TypeInfer.tables.test.cpp | 43 +++- tests/TypeInfer.test.cpp | 27 +- tests/TypeInfer.tryUnify.test.cpp | 1 - tests/TypeInfer.typePacks.cpp | 3 - tests/TypeInfer.unionTypes.test.cpp | 1 - tests/TypePack.test.cpp | 1 - tests/TypeVar.test.cpp | 1 - tests/conformance/coroutine.lua | 9 + tests/conformance/coverage.lua | 8 + tests/conformance/debug.lua | 3 + tests/main.cpp | 7 +- tools/natvis/Analysis.natvis | 78 ++++++ tools/natvis/Ast.natvis | 25 ++ tools/natvis/VM.natvis | 269 ++++++++++++++++++++ 85 files changed, 1507 insertions(+), 576 deletions(-) create mode 100644 Ast/include/Luau/ParseResult.h create mode 100644 tools/natvis/Analysis.natvis create mode 100644 tools/natvis/Ast.natvis create mode 100644 tools/natvis/VM.natvis diff --git a/Analysis/include/Luau/Documentation.h b/Analysis/include/Luau/Documentation.h index 68ff3a7c2..7a2b56ffb 100644 --- a/Analysis/include/Luau/Documentation.h +++ b/Analysis/include/Luau/Documentation.h @@ -21,6 +21,7 @@ struct BasicDocumentation { std::string documentation; std::string learnMoreLink; + std::string codeSample; }; struct FunctionParameterDocumentation @@ -37,6 +38,7 @@ struct FunctionDocumentation std::vector parameters; std::vector returns; std::string learnMoreLink; + std::string codeSample; }; struct OverloadedFunctionDocumentation @@ -52,6 +54,7 @@ struct TableDocumentation std::string documentation; Luau::DenseHashMap keys; std::string learnMoreLink; + std::string codeSample; }; using DocumentationDatabase = Luau::DenseHashMap; diff --git a/Analysis/include/Luau/Frontend.h b/Analysis/include/Luau/Frontend.h index 1f64db30c..0bf8f362c 100644 --- a/Analysis/include/Luau/Frontend.h +++ b/Analysis/include/Luau/Frontend.h @@ -24,6 +24,7 @@ struct TypeChecker; struct FileResolver; struct ModuleResolver; struct ParseResult; +struct HotComment; struct LoadDefinitionFileResult { @@ -35,7 +36,7 @@ struct LoadDefinitionFileResult LoadDefinitionFileResult loadDefinitionFile( TypeChecker& typeChecker, ScopePtr targetScope, std::string_view definition, const std::string& packageName); -std::optional parseMode(const std::vector& hotcomments); +std::optional parseMode(const std::vector& hotcomments); std::vector parsePathExpr(const AstExpr& pathExpr); diff --git a/Analysis/include/Luau/Linter.h b/Analysis/include/Luau/Linter.h index ec3c124d7..6c7ce47fe 100644 --- a/Analysis/include/Luau/Linter.h +++ b/Analysis/include/Luau/Linter.h @@ -14,6 +14,7 @@ class AstStat; class AstNameTable; struct TypeChecker; struct Module; +struct HotComment; using ScopePtr = std::shared_ptr; @@ -50,6 +51,7 @@ struct LintWarning Code_TableOperations = 23, Code_DuplicateCondition = 24, Code_MisleadingAndOr = 25, + Code_CommentDirective = 26, Code__Count }; @@ -60,7 +62,7 @@ struct LintWarning static const char* getName(Code code); static Code parseName(const char* name); - static uint64_t parseMask(const std::vector& hotcomments); + static uint64_t parseMask(const std::vector& hotcomments); }; struct LintResult @@ -90,7 +92,8 @@ struct LintOptions void setDefaults(); }; -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options); +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, + const std::vector& hotcomments, const LintOptions& options); std::vector getDeprecatedGlobals(const AstNameTable& names); diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 1bf0473c1..612007711 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -6,7 +6,7 @@ #include "Luau/TypedAllocator.h" #include "Luau/ParseOptions.h" #include "Luau/Error.h" -#include "Luau/Parser.h" +#include "Luau/ParseResult.h" #include #include @@ -37,8 +37,8 @@ struct SourceModule AstStatBlock* root = nullptr; std::optional mode; - uint64_t ignoreLints = 0; + std::vector hotcomments; std::vector commentLocations; SourceModule() diff --git a/Analysis/include/Luau/Quantify.h b/Analysis/include/Luau/Quantify.h index f46df1460..e48cad404 100644 --- a/Analysis/include/Luau/Quantify.h +++ b/Analysis/include/Luau/Quantify.h @@ -6,9 +6,6 @@ namespace Luau { -struct Module; -using ModulePtr = std::shared_ptr; - -void quantify(ModulePtr module, TypeId ty, TypeLevel level); +void quantify(TypeId ty, TypeLevel level); } // namespace Luau diff --git a/Analysis/include/Luau/Substitution.h b/Analysis/include/Luau/Substitution.h index f85b42690..9662d5b39 100644 --- a/Analysis/include/Luau/Substitution.h +++ b/Analysis/include/Luau/Substitution.h @@ -101,9 +101,6 @@ struct Tarjan // This is hot code, so we optimize recursion to a stack. TarjanResult loop(); - // Clear the state - void clear(); - // Find or create the index for a vertex. // Return a boolean which is `true` if it's a freshly created index. std::pair indexify(TypeId ty); @@ -166,7 +163,17 @@ struct FindDirty : Tarjan // and replaces them with clean ones. struct Substitution : FindDirty { - ModulePtr currentModule; +protected: + Substitution(const TxnLog* log_, TypeArena* arena) + : arena(arena) + { + log = log_; + LUAU_ASSERT(log); + LUAU_ASSERT(arena); + } + +public: + TypeArena* arena; DenseHashMap newTypes{nullptr}; DenseHashMap newPacks{nullptr}; @@ -192,12 +199,13 @@ struct Substitution : FindDirty template TypeId addType(const T& tv) { - return currentModule->internalTypes.addType(tv); + return arena->addType(tv); } + template TypePackId addTypePack(const T& tp) { - return currentModule->internalTypes.addTypePack(TypePackVar{tp}); + return arena->addTypePack(TypePackVar{tp}); } }; diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 5592fa1f5..3c5ded3cc 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -5,7 +5,6 @@ #include "Luau/Error.h" #include "Luau/Module.h" #include "Luau/Symbol.h" -#include "Luau/Parser.h" #include "Luau/Substitution.h" #include "Luau/TxnLog.h" #include "Luau/TypePack.h" @@ -37,6 +36,15 @@ struct Unifier; // A substitution which replaces generic types in a given set by free types. struct ReplaceGenerics : Substitution { + ReplaceGenerics( + const TxnLog* log, TypeArena* arena, TypeLevel level, const std::vector& generics, const std::vector& genericPacks) + : Substitution(log, arena) + , level(level) + , generics(generics) + , genericPacks(genericPacks) + { + } + TypeLevel level; std::vector generics; std::vector genericPacks; @@ -50,8 +58,13 @@ struct ReplaceGenerics : Substitution // A substitution which replaces generic functions by monomorphic functions struct Instantiation : Substitution { + Instantiation(const TxnLog* log, TypeArena* arena, TypeLevel level) + : Substitution(log, arena) + , level(level) + { + } + TypeLevel level; - ReplaceGenerics replaceGenerics; bool ignoreChildren(TypeId ty) override; bool isDirty(TypeId ty) override; bool isDirty(TypePackId tp) override; @@ -62,6 +75,12 @@ struct Instantiation : Substitution // A substitution which replaces free types by generic types. struct Quantification : Substitution { + Quantification(TypeArena* arena, TypeLevel level) + : Substitution(TxnLog::empty(), arena) + , level(level) + { + } + TypeLevel level; std::vector generics; std::vector genericPacks; @@ -74,6 +93,13 @@ struct Quantification : Substitution // A substitution which replaces free types by any struct Anyification : Substitution { + Anyification(TypeArena* arena, TypeId anyType, TypePackId anyTypePack) + : Substitution(TxnLog::empty(), arena) + , anyType(anyType) + , anyTypePack(anyTypePack) + { + } + TypeId anyType; TypePackId anyTypePack; bool isDirty(TypeId ty) override; @@ -85,6 +111,13 @@ struct Anyification : Substitution // A substitution which replaces the type parameters of a type function by arguments struct ApplyTypeFunction : Substitution { + ApplyTypeFunction(TypeArena* arena, TypeLevel level) + : Substitution(TxnLog::empty(), arena) + , level(level) + , encounteredForwardedType(false) + { + } + TypeLevel level; bool encounteredForwardedType; std::unordered_map typeArguments; @@ -351,8 +384,7 @@ struct TypeChecker // Note: `scope` must be a fresh scope. GenericTypeDefinitions createGenericTypes(const ScopePtr& scope, std::optional levelOpt, const AstNode& node, - const AstArray& genericNames, const AstArray& genericPackNames, - bool useCache = false); + const AstArray& genericNames, const AstArray& genericPackNames, bool useCache = false); public: ErrorVec resolve(const PredicateVec& predicates, const ScopePtr& scope, bool sense); @@ -392,11 +424,6 @@ struct TypeChecker ModulePtr currentModule; ModuleName currentModuleName; - Instantiation instantiation; - Quantification quantification; - Anyification anyification; - ApplyTypeFunction applyTypeFunction; - std::function prepareModuleScope; InternalErrorReporter* iceHandler; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 85099e12c..5a1ae3975 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -7,6 +7,7 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypePack.h" +#include "Luau/Parser.h" // TODO: only needed for autocompleteSource which is deprecated #include #include @@ -14,9 +15,9 @@ LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauCompleteBrokenStringParams, false); LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false); +LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -380,7 +381,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId { // We are walking up the class hierarchy, so if we encounter a property that we have // already populated, it takes precedence over the property we found just now. - if (result.count(name) == 0 && name != Parser::errorName) + if (result.count(name) == 0 && name != kParseNameError) { Luau::TypeId type = Luau::follow(prop.type); TypeCorrectKind typeCorrect = indexType == PropIndexType::Key ? TypeCorrectKind::Correct @@ -948,9 +949,12 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } } - for (size_t i = 0; i < node->returnAnnotation.types.size; i++) + if (!node->returnAnnotation) + return result; + + for (size_t i = 0; i < node->returnAnnotation->types.size; i++) { - AstType* ret = node->returnAnnotation.types.data[i]; + AstType* ret = node->returnAnnotation->types.data[i]; if (ret->location.containsClosed(position)) { @@ -965,7 +969,7 @@ AutocompleteEntryMap autocompleteTypeNames(const Module& module, Position positi } } - if (AstTypePack* retTp = node->returnAnnotation.tailType) + if (AstTypePack* retTp = node->returnAnnotation->tailType) { if (auto variadic = retTp->as()) { @@ -1136,7 +1140,7 @@ static AutocompleteEntryMap autocompleteStatement( AstNode* parent = ancestry.rbegin()[1]; if (AstStatIf* statIf = parent->as()) { - if (!statIf->elsebody || (statIf->hasElse && statIf->elseLocation.containsClosed(position))) + if (!statIf->elsebody || (statIf->elseLocation && statIf->elseLocation->containsClosed(position))) { result.emplace("else", AutocompleteEntry{AutocompleteEntryKind::Keyword}); result.emplace("elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}); @@ -1164,8 +1168,7 @@ static AutocompleteEntryMap autocompleteStatement( return result; } -// Returns true if completions were generated (completions will be inserted into 'outResult') -// Returns false if no completions were generated +// Returns true iff `node` was handled by this function (completions, if any, are returned in `outResult`) static bool autocompleteIfElseExpression( const AstNode* node, const std::vector& ancestry, const Position& position, AutocompleteEntryMap& outResult) { @@ -1173,6 +1176,13 @@ static bool autocompleteIfElseExpression( if (!parent) return false; + if (FFlag::LuauIfElseExprFixCompletionIssue && node->is()) + { + // Don't try to complete when the current node is an if-else expression (i.e. only try to complete when the node is a child of an if-else + // expression. + return true; + } + AstExprIfElse* ifElseExpr = parent->as(); if (!ifElseExpr || ifElseExpr->condition->location.containsClosed(position)) { @@ -1310,7 +1320,7 @@ static std::optional autocompleteStringParams(const Source return std::nullopt; } - if (!nodes.back()->is() && (!FFlag::LuauCompleteBrokenStringParams || !nodes.back()->is())) + if (!nodes.back()->is() && !nodes.back()->is()) { return std::nullopt; } @@ -1408,8 +1418,8 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (auto typeReference = node->as()) { - if (typeReference->hasPrefix) - return {autocompleteModuleTypes(*module, position, typeReference->prefix.value), finder.ancestry}; + if (typeReference->prefix) + return {autocompleteModuleTypes(*module, position, typeReference->prefix->value), finder.ancestry}; else return {autocompleteTypeNames(*module, position, finder.ancestry), finder.ancestry}; } @@ -1419,9 +1429,9 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M } else if (AstStatLocal* statLocal = node->as()) { - if (statLocal->vars.size == 1 && (!statLocal->hasEqualsSign || position < statLocal->equalsSignLocation.begin)) + if (statLocal->vars.size == 1 && (!statLocal->equalsSignLocation || position < statLocal->equalsSignLocation->begin)) return {{{"function", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (statLocal->hasEqualsSign && position >= statLocal->equalsSignLocation.end) + else if (statLocal->equalsSignLocation && position >= statLocal->equalsSignLocation->end) return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; else return {}; @@ -1449,7 +1459,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M if (!statForIn->hasIn || position <= statForIn->inLocation.begin) { AstLocal* lastName = statForIn->vars.data[statForIn->vars.size - 1]; - if (lastName->name == Parser::errorName || lastName->location.containsClosed(position)) + if (lastName->name == kParseNameError || lastName->location.containsClosed(position)) { // Here we are either working with a missing binding (as would be the case in a bare "for" keyword) or // the cursor is still touching a binding name. The user is still typing a new name, so we should not offer @@ -1499,7 +1509,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M else if (AstStatWhile* statWhile = extractStat(finder.ancestry); statWhile && !statWhile->hasDo) return {{{"do", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; - else if (AstStatIf* statIf = node->as(); statIf && !statIf->hasElse) + else if (AstStatIf* statIf = node->as(); statIf && !statIf->elseLocation.has_value()) { return {{{"else", AutocompleteEntry{AutocompleteEntryKind::Keyword}}, {"elseif", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; @@ -1508,11 +1518,11 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M { if (statIf->condition->is()) return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; - else if (!statIf->hasThen || statIf->thenLocation.containsClosed(position)) + else if (!statIf->thenLocation || statIf->thenLocation->containsClosed(position)) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; } else if (AstStatIf* statIf = extractStat(finder.ancestry); - statIf && (!statIf->hasThen || statIf->thenLocation.containsClosed(position))) + statIf && (!statIf->thenLocation || statIf->thenLocation->containsClosed(position))) return {{{"then", AutocompleteEntry{AutocompleteEntryKind::Keyword}}}, finder.ancestry}; else if (AstStatRepeat* statRepeat = node->as(); statRepeat && statRepeat->condition->is()) return {autocompleteExpression(sourceModule, *module, typeChecker, typeArena, finder.ancestry, position), finder.ancestry}; @@ -1612,6 +1622,7 @@ AutocompleteResult autocomplete(Frontend& frontend, const ModuleName& moduleName OwningAutocompleteResult autocompleteSource(Frontend& frontend, std::string_view source, Position position, StringCompletionCallback callback) { + // TODO: Remove #include "Luau/Parser.h" with this function auto sourceModule = std::make_unique(); ParseOptions parseOptions; parseOptions.captureComments = true; diff --git a/Analysis/src/Config.cpp b/Analysis/src/Config.cpp index d9fc44f81..35a2259d1 100644 --- a/Analysis/src/Config.cpp +++ b/Analysis/src/Config.cpp @@ -1,7 +1,7 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Config.h" -#include "Luau/Parser.h" +#include "Luau/Lexer.h" #include "Luau/StringUtils.h" namespace diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index f3ef88fc5..bf6e1193f 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -167,7 +167,7 @@ declare function gcinfo(): number foreach: ({[K]: V}, (K, V) -> ()) -> (), foreachi: ({V}, (number, V) -> ()) -> (), - move: ({V}, number, number, number, {V}?) -> (), + move: ({V}, number, number, number, {V}?) -> {V}, clear: ({[K]: V}) -> (), freeze: ({[K]: V}) -> {[K]: V}, diff --git a/Analysis/src/Frontend.cpp b/Analysis/src/Frontend.cpp index 9001b19df..d8906f6e7 100644 --- a/Analysis/src/Frontend.cpp +++ b/Analysis/src/Frontend.cpp @@ -4,6 +4,7 @@ #include "Luau/Common.h" #include "Luau/Config.h" #include "Luau/FileResolver.h" +#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/StringUtils.h" #include "Luau/TimeTrace.h" @@ -16,23 +17,25 @@ #include LUAU_FASTFLAG(LuauInferInNoCheckMode) -LUAU_FASTFLAGVARIABLE(LuauTypeCheckTwice, false) LUAU_FASTFLAGVARIABLE(LuauKnowsTheDataModel3, false) namespace Luau { -std::optional parseMode(const std::vector& hotcomments) +std::optional parseMode(const std::vector& hotcomments) { - for (const std::string& hc : hotcomments) + for (const HotComment& hc : hotcomments) { - if (hc == "nocheck") + if (!hc.header) + continue; + + if (hc.content == "nocheck") return Mode::NoCheck; - if (hc == "nonstrict") + if (hc.content == "nonstrict") return Mode::Nonstrict; - if (hc == "strict") + if (hc.content == "strict") return Mode::Strict; } @@ -607,13 +610,15 @@ std::pair Frontend::lintFragment(std::string_view sour SourceModule sourceModule = parse(ModuleName{}, source, config.parseOptions); + uint64_t ignoreLints = LintWarning::parseMask(sourceModule.hotcomments); + Luau::LintOptions lintOptions = enabledLintWarnings.value_or(config.enabledLint); - lintOptions.warningMask &= sourceModule.ignoreLints; + lintOptions.warningMask &= ~ignoreLints; double timestamp = getTimestamp(); - std::vector warnings = - Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, enabledLintWarnings.value_or(config.enabledLint)); + std::vector warnings = Luau::lint(sourceModule.root, *sourceModule.names.get(), typeChecker.globalScope, nullptr, + sourceModule.hotcomments, enabledLintWarnings.value_or(config.enabledLint)); stats.timeLint += getTimestamp() - timestamp; @@ -651,8 +656,10 @@ LintResult Frontend::lint(const SourceModule& module, std::optionalgetConfig(module.name); + uint64_t ignoreLints = LintWarning::parseMask(module.hotcomments); + LintOptions options = enabledLintWarnings.value_or(config.enabledLint); - options.warningMask &= ~module.ignoreLints; + options.warningMask &= ~ignoreLints; Mode mode = module.mode.value_or(config.mode); if (mode != Mode::NoCheck) @@ -671,7 +678,7 @@ LintResult Frontend::lint(const SourceModule& module, std::optional warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), options); + std::vector warnings = Luau::lint(module.root, *module.names, environmentScope, modulePtr.get(), module.hotcomments, options); stats.timeLint += getTimestamp() - timestamp; @@ -839,7 +846,6 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const { sourceModule.root = parseResult.root; sourceModule.mode = parseMode(parseResult.hotcomments); - sourceModule.ignoreLints = LintWarning::parseMask(parseResult.hotcomments); } else { @@ -848,8 +854,13 @@ SourceModule Frontend::parse(const ModuleName& name, std::string_view src, const } sourceModule.name = name; + if (parseOptions.captureComments) + { sourceModule.commentLocations = std::move(parseResult.commentLocations); + sourceModule.hotcomments = std::move(parseResult.hotcomments); + } + return sourceModule; } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index 8dd597e17..ec3991581 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -150,10 +150,21 @@ struct AstJsonEncoder : public AstVisitor { writeRaw(std::to_string(i)); } + void write(std::nullptr_t) + { + writeRaw("null"); + } void write(std::string_view str) { writeString(str); } + void write(std::optional name) + { + if (name) + write(*name); + else + writeRaw("null"); + } void write(AstName name) { writeString(name.value ? name.value : ""); @@ -177,7 +188,16 @@ struct AstJsonEncoder : public AstVisitor void write(AstLocal* local) { - write(local->name); + writeRaw("{"); + bool c = pushComma(); + if (local->annotation != nullptr) + write("type", local->annotation); + else + write("type", nullptr); + write("name", local->name); + write("location", local->location); + popComma(c); + writeRaw("}"); } void writeNode(AstNode* node) @@ -314,7 +334,7 @@ struct AstJsonEncoder : public AstVisitor if (node->self) PROP(self); PROP(args); - if (node->hasReturnAnnotation) + if (node->returnAnnotation) PROP(returnAnnotation); PROP(vararg); PROP(varargLocation); @@ -328,6 +348,14 @@ struct AstJsonEncoder : public AstVisitor }); } + void write(const std::optional& typeList) + { + if (typeList) + write(*typeList); + else + writeRaw("null"); + } + void write(const AstTypeList& typeList) { writeRaw("{"); @@ -531,7 +559,7 @@ struct AstJsonEncoder : public AstVisitor PROP(thenbody); if (node->elsebody) PROP(elsebody); - PROP(hasThen); + write("hasThen", node->thenLocation.has_value()); PROP(hasEnd); }); } @@ -715,7 +743,7 @@ struct AstJsonEncoder : public AstVisitor void write(class AstTypeReference* node) { writeNode(node, "AstTypeReference", [&]() { - if (node->hasPrefix) + if (node->prefix) PROP(prefix); PROP(name); PROP(parameters); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 2ba6a0fce..8d7d2d97f 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -12,6 +12,8 @@ #include #include +LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) + namespace Luau { @@ -44,6 +46,7 @@ static const char* kWarningNames[] = { "TableOperations", "DuplicateCondition", "MisleadingAndOr", + "CommentDirective", }; // clang-format on @@ -732,13 +735,13 @@ class LintLocalHygiene : AstVisitor bool visit(AstTypeReference* node) override { - if (!node->hasPrefix) + if (!node->prefix) return true; - if (!imports.contains(node->prefix)) + if (!imports.contains(*node->prefix)) return true; - AstLocal* astLocal = imports[node->prefix]; + AstLocal* astLocal = imports[*node->prefix]; Local& local = locals[astLocal]; LUAU_ASSERT(local.import); local.used = true; @@ -2527,13 +2530,108 @@ static void fillBuiltinGlobals(LintContext& context, const AstNameTable& names, } } +static const char* fuzzyMatch(std::string_view str, const char** array, size_t size) +{ + if (FInt::LuauSuggestionDistance == 0) + return nullptr; + + size_t bestDistance = FInt::LuauSuggestionDistance; + size_t bestMatch = size; + + for (size_t i = 0; i < size; ++i) + { + size_t ed = editDistance(str, array[i]); + + if (ed <= bestDistance) + { + bestDistance = ed; + bestMatch = i; + } + } + + return bestMatch < size ? array[bestMatch] : nullptr; +} + +static void lintComments(LintContext& context, const std::vector& hotcomments) +{ + bool seenMode = false; + + for (const HotComment& hc : hotcomments) + { + // We reserve --! for various informational (non-directive) comments + if (hc.content.empty() || hc.content[0] == ' ' || hc.content[0] == '\t') + continue; + + if (!hc.header) + { + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive is ignored because it is placed after the first non-comment token"); + } + else + { + std::string::size_type space = hc.content.find_first_of(" \t"); + std::string_view first = std::string_view(hc.content).substr(0, space); + + if (first == "nolint") + { + std::string::size_type notspace = hc.content.find_first_not_of(" \t", space); + + if (space == std::string::npos || notspace == std::string::npos) + { + // disables all lints + } + else if (LintWarning::parseName(hc.content.c_str() + notspace) == LintWarning::Code_Unknown) + { + const char* rule = hc.content.c_str() + notspace; + + // skip Unknown + if (const char* suggestion = fuzzyMatch(rule, kWarningNames + 1, LintWarning::Code__Count - 1)) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "nolint directive refers to unknown lint rule '%s'; did you mean '%s'?", rule, suggestion); + else + emitWarning( + context, LintWarning::Code_CommentDirective, hc.location, "nolint directive refers to unknown lint rule '%s'", rule); + } + } + else if (first == "nocheck" || first == "nonstrict" || first == "strict") + { + if (space != std::string::npos) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive with the type checking mode has extra symbols at the end of the line"); + else if (seenMode) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, + "Comment directive with the type checking mode has already been used"); + else + seenMode = true; + } + else + { + static const char* kHotComments[] = { + "nolint", + "nocheck", + "nonstrict", + "strict", + }; + + if (const char* suggestion = fuzzyMatch(first, kHotComments, std::size(kHotComments))) + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'; did you mean '%s'?", + int(first.size()), first.data(), suggestion); + else + emitWarning(context, LintWarning::Code_CommentDirective, hc.location, "Unknown comment directive '%.*s'", int(first.size()), + first.data()); + } + } + } +} + void LintOptions::setDefaults() { // By default, we enable all warnings warningMask = ~0ull; } -std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, const LintOptions& options) +std::vector lint(AstStat* root, const AstNameTable& names, const ScopePtr& env, const Module* module, + const std::vector& hotcomments, const LintOptions& options) { LintContext context; @@ -2609,6 +2707,9 @@ std::vector lint(AstStat* root, const AstNameTable& names, const Sc if (context.warningEnabled(LintWarning::Code_MisleadingAndOr)) LintMisleadingAndOr::process(context); + if (context.warningEnabled(LintWarning::Code_CommentDirective)) + lintComments(context, hotcomments); + std::sort(context.result.begin(), context.result.end(), WarningComparator()); return context.result; @@ -2630,23 +2731,30 @@ LintWarning::Code LintWarning::parseName(const char* name) return Code_Unknown; } -uint64_t LintWarning::parseMask(const std::vector& hotcomments) +uint64_t LintWarning::parseMask(const std::vector& hotcomments) { uint64_t result = 0; - for (const std::string& hc : hotcomments) + for (const HotComment& hc : hotcomments) { - if (hc.compare(0, 6, "nolint") != 0) + if (!hc.header) + continue; + + if (hc.content.compare(0, 6, "nolint") != 0) continue; - std::string::size_type name = hc.find_first_not_of(" \t", 6); + std::string::size_type name = hc.content.find_first_not_of(" \t", 6); // --!nolint disables everything if (name == std::string::npos) return ~0ull; + // --!nolint needs to be followed by a whitespace character + if (name == 6) + continue; + // --!nolint name disables the specific lint - LintWarning::Code code = LintWarning::parseName(hc.c_str() + name); + LintWarning::Code code = LintWarning::parseName(hc.content.c_str() + name); if (code != LintWarning::Code_Unknown) result |= 1ull << int(code); diff --git a/Analysis/src/Quantify.cpp b/Analysis/src/Quantify.cpp index 04ebffc1b..94e169f1f 100644 --- a/Analysis/src/Quantify.cpp +++ b/Analysis/src/Quantify.cpp @@ -9,14 +9,12 @@ namespace Luau struct Quantifier { - ModulePtr module; TypeLevel level; std::vector generics; std::vector genericPacks; - Quantifier(ModulePtr module, TypeLevel level) - : module(module) - , level(level) + Quantifier(TypeLevel level) + : level(level) { } @@ -76,9 +74,9 @@ struct Quantifier } }; -void quantify(ModulePtr module, TypeId ty, TypeLevel level) +void quantify(TypeId ty, TypeLevel level) { - Quantifier q{std::move(module), level}; + Quantifier q{level}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, q, seen); diff --git a/Analysis/src/Substitution.cpp b/Analysis/src/Substitution.cpp index bacbca762..770c7a47d 100644 --- a/Analysis/src/Substitution.cpp +++ b/Analysis/src/Substitution.cpp @@ -226,27 +226,11 @@ TarjanResult Tarjan::loop() return TarjanResult::Ok; } -void Tarjan::clear() -{ - typeToIndex.clear(); - indexToType.clear(); - packToIndex.clear(); - indexToPack.clear(); - lowlink.clear(); - stack.clear(); - onStack.clear(); - - edgesTy.clear(); - edgesTp.clear(); - worklist.clear(); -} - TarjanResult Tarjan::visitRoot(TypeId ty) { childCount = 0; ty = log->follow(ty); - clear(); auto [index, fresh] = indexify(ty); worklist.push_back({index, -1, -1}); return loop(); @@ -257,7 +241,6 @@ TarjanResult Tarjan::visitRoot(TypePackId tp) childCount = 0; tp = log->follow(tp); - clear(); auto [index, fresh] = indexify(tp); worklist.push_back({index, -1, -1}); return loop(); @@ -314,21 +297,17 @@ void FindDirty::visitSCC(int index) TarjanResult FindDirty::findDirty(TypeId ty) { - dirty.clear(); return visitRoot(ty); } TarjanResult FindDirty::findDirty(TypePackId tp) { - dirty.clear(); return visitRoot(tp); } std::optional Substitution::substitute(TypeId ty) { ty = log->follow(ty); - newTypes.clear(); - newPacks.clear(); auto result = findDirty(ty); if (result != TarjanResult::Ok) @@ -347,8 +326,6 @@ std::optional Substitution::substitute(TypeId ty) std::optional Substitution::substitute(TypePackId tp) { tp = log->follow(tp); - newTypes.clear(); - newPacks.clear(); auto result = findDirty(tp); if (result != TarjanResult::Ok) diff --git a/Analysis/src/TopoSortStatements.cpp b/Analysis/src/TopoSortStatements.cpp index dba694be0..678001bf8 100644 --- a/Analysis/src/TopoSortStatements.cpp +++ b/Analysis/src/TopoSortStatements.cpp @@ -26,9 +26,10 @@ * 3. Cyclic dependencies can be resolved by picking an arbitrary statement to check first. */ -#include "Luau/Parser.h" +#include "Luau/Ast.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" +#include "Luau/StringUtils.h" #include #include diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index f59086834..54bd0d5e7 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -933,12 +933,12 @@ struct Printer writer.symbol(")"); - if (writeTypes && func.hasReturnAnnotation) + if (writeTypes && func.returnAnnotation) { writer.symbol(":"); writer.space(); - visualizeTypeList(func.returnAnnotation, false); + visualizeTypeList(*func.returnAnnotation, false); } visualizeBlock(*func.body); @@ -989,9 +989,9 @@ struct Printer advance(typeAnnotation.location.begin); if (const auto& a = typeAnnotation.as()) { - if (a->hasPrefix) + if (a->prefix) { - writer.write(a->prefix.value); + writer.write(a->prefix->value); writer.symbol("."); } diff --git a/Analysis/src/TypeAttach.cpp b/Analysis/src/TypeAttach.cpp index 2208213f7..d575e023d 100644 --- a/Analysis/src/TypeAttach.cpp +++ b/Analysis/src/TypeAttach.cpp @@ -3,7 +3,6 @@ #include "Luau/Error.h" #include "Luau/Module.h" -#include "Luau/Parser.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" #include "Luau/ToString.h" @@ -476,12 +475,11 @@ class TypeAttacher : public AstVisitor visitLocal(arg); } - if (!fn->hasReturnAnnotation) + if (!fn->returnAnnotation) { if (auto result = getScope(fn->body->location)) { TypePackId ret = result->returnType; - fn->hasReturnAnnotation = true; AstTypePack* variadicAnnotation = nullptr; const auto& [v, tail] = flatten(ret); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index f1c314cd2..c29699b7d 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -3,7 +3,6 @@ #include "Luau/Common.h" #include "Luau/ModuleResolver.h" -#include "Luau/Parser.h" #include "Luau/Quantify.h" #include "Luau/RecursionCounter.h" #include "Luau/Scope.h" @@ -24,16 +23,12 @@ LUAU_FASTINTVARIABLE(LuauTypeInferTypePackLoopLimit, 5000) LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) -LUAU_FASTFLAGVARIABLE(LuauGroupExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseBranchTypeUnion, false) -LUAU_FASTFLAGVARIABLE(LuauIfElseExpectedType2, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) -LUAU_FASTFLAGVARIABLE(LuauLengthOnCompositeType, false) LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) @@ -43,7 +38,6 @@ LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauPerModuleUnificationCache, false) LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAG(LuauUnionTagMatchFix) @@ -293,13 +287,10 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona GenericError{"Free types leaked into this module's public interface. This is an internal Luau error; please report it."}}); } - if (FFlag::LuauPerModuleUnificationCache) - { - // Clear unifier cache since it's keyed off internal types that get deallocated - // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. - unifierState.cachedUnify.clear(); - unifierState.skipCacheForType.clear(); - } + // Clear unifier cache since it's keyed off internal types that get deallocated + // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. + unifierState.cachedUnify.clear(); + unifierState.skipCacheForType.clear(); if (FFlag::LuauTwoPassAliasDefinitionFix) duplicateTypeAliases.clear(); @@ -509,7 +500,7 @@ LUAU_NOINLINE void TypeChecker::checkBlockTypeAliases(const ScopePtr& scope, std { if (const auto& typealias = stat->as()) { - if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == Parser::errorName) + if (FFlag::LuauTwoPassAliasDefinitionFix && typealias->name == kParseNameError) continue; auto& bindings = typealias->exported ? scope->exportedTypeBindings : scope->privateTypeBindings; @@ -1193,7 +1184,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias Name name = typealias.name.value; // If the alias is missing a name, we can't do anything with it. Ignore it. - if (FFlag::LuauTwoPassAliasDefinitionFix && name == Parser::errorName) + if (FFlag::LuauTwoPassAliasDefinitionFix && name == kParseNameError) return; std::optional binding; @@ -1222,7 +1213,8 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias if (FFlag::LuauProperTypeLevels) aliasScope->level.subLevel = subLevel; - auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); + auto [generics, genericPacks] = + createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); TypeId ty = freshType(aliasScope); FreeTypeVar* ftv = getMutable(ty); @@ -1464,7 +1456,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& ExprResult result; if (auto a = expr.as()) - result = checkExpr(scope, *a->expr, FFlag::LuauGroupExpectedType ? expectedType : std::nullopt); + result = checkExpr(scope, *a->expr, expectedType); else if (expr.is()) result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) @@ -1508,7 +1500,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& else if (auto a = expr.as()) result = checkExpr(scope, *a); else if (auto a = expr.as()) - result = checkExpr(scope, *a, FFlag::LuauIfElseExpectedType2 ? expectedType : std::nullopt); + result = checkExpr(scope, *a, expectedType); else ice("Unhandled AstExpr?"); @@ -2093,6 +2085,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn return {numberType}; } case AstExprUnary::Len: + { tablify(operandType); operandType = stripFromNilAndReport(operandType, expr.location); @@ -2100,30 +2093,13 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn if (get(operandType)) return {errorRecoveryType(scope)}; - if (FFlag::LuauLengthOnCompositeType) - { - DenseHashSet seen{nullptr}; - - if (!hasLength(operandType, seen, &recursionCount)) - reportError(TypeError{expr.location, NotATable{operandType}}); - } - else - { - if (get(operandType)) - return {numberType}; // Not strictly correct: metatables permit overriding this - - if (auto p = get(operandType)) - { - if (p->type == PrimitiveTypeVar::String) - return {numberType}; - } + DenseHashSet seen{nullptr}; - if (!getTableType(operandType)) - reportError(TypeError{expr.location, NotATable{operandType}}); - } + if (!hasLength(operandType, seen, &recursionCount)) + reportError(TypeError{expr.location, NotATable{operandType}}); return {numberType}; - + } default: ice("Unknown AstExprUnary " + std::to_string(int(expr.op))); } @@ -2618,22 +2594,11 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIf resolve(result.predicates, falseScope, false); ExprResult falseType = checkExpr(falseScope, *expr.falseExpr, expectedType); - if (FFlag::LuauIfElseBranchTypeUnion) - { - if (falseType.type == trueType.type) - return {trueType.type}; - - std::vector types = reduceUnion({trueType.type, falseType.type}); - return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; - } - else - { - unify(falseType.type, trueType.type, expr.location); - - // TODO: normalize(UnionTypeVar{{trueType, falseType}}) - // For now both trueType and falseType must be the same type. + if (falseType.type == trueType.type) return {trueType.type}; - } + + std::vector types = reduceUnion({trueType.type, falseType.type}); + return {types.size() == 1 ? types[0] : addType(UnionTypeVar{std::move(types)})}; } TypeId TypeChecker::checkLValue(const ScopePtr& scope, const AstExpr& expr) @@ -2986,8 +2951,8 @@ std::pair TypeChecker::checkFunctionSignature( auto [generics, genericPacks] = createGenericTypes(funScope, std::nullopt, expr, expr.generics, expr.genericPacks); TypePackId retPack; - if (expr.hasReturnAnnotation) - retPack = resolveTypePack(funScope, expr.returnAnnotation); + if (expr.returnAnnotation) + retPack = resolveTypePack(funScope, *expr.returnAnnotation); else if (isNonstrictMode()) retPack = anyTypePack; else if (expectedFunctionType) @@ -3181,7 +3146,7 @@ void TypeChecker::checkFunctionBody(const ScopePtr& scope, TypeId ty, const AstE // If we're in nonstrict mode we want to only report this missing return // statement if there are type annotations on the function. In strict mode // we report it regardless. - if (!isNonstrictMode() || function.hasReturnAnnotation) + if (!isNonstrictMode() || function.returnAnnotation) { reportError(getEndLocation(function), FunctionExitsWithoutReturning{funTy->retType}); } @@ -4403,11 +4368,7 @@ TypeId Instantiation::clean(TypeId ty) // Annoyingly, we have to do this even if there are no generics, // to replace any generic tables. - replaceGenerics.log = log; - replaceGenerics.level = level; - replaceGenerics.currentModule = currentModule; - replaceGenerics.generics.assign(ftv->generics.begin(), ftv->generics.end()); - replaceGenerics.genericPacks.assign(ftv->genericPacks.begin(), ftv->genericPacks.end()); + ReplaceGenerics replaceGenerics{log, arena, level, ftv->generics, ftv->genericPacks}; // TODO: What to do if this returns nullopt? // We don't have access to the error-reporting machinery @@ -4573,16 +4534,11 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (FFlag::LuauQuantifyInPlace2) { - Luau::quantify(currentModule, ty, scope->level); + Luau::quantify(ty, scope->level); return ty; } - quantification.log = TxnLog::empty(); - quantification.level = scope->level; - quantification.generics.clear(); - quantification.genericPacks.clear(); - quantification.currentModule = currentModule; - + Quantification quantification{¤tModule->internalTypes, scope->level}; std::optional qty = quantification.substitute(ty); if (!qty.has_value()) @@ -4596,18 +4552,14 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location FunctionTypeVar* qftv = getMutable(*qty); LUAU_ASSERT(qftv); - qftv->generics = quantification.generics; - qftv->genericPacks = quantification.genericPacks; + qftv->generics = std::move(quantification.generics); + qftv->genericPacks = std::move(quantification.genericPacks); return *qty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { - LUAU_ASSERT(log != nullptr); - - instantiation.log = FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(); - instantiation.level = scope->level; - instantiation.currentModule = currentModule; + Instantiation instantiation{FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(), ¤tModule->internalTypes, scope->level}; std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; @@ -4620,10 +4572,7 @@ TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location locat TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) { - anyification.log = TxnLog::empty(); - anyification.anyType = anyType; - anyification.anyTypePack = anyTypePack; - anyification.currentModule = currentModule; + Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4636,10 +4585,7 @@ TypeId TypeChecker::anyify(const ScopePtr& scope, TypeId ty, Location location) TypePackId TypeChecker::anyify(const ScopePtr& scope, TypePackId ty, Location location) { - anyification.log = TxnLog::empty(); - anyification.anyType = anyType; - anyification.anyTypePack = anyTypePack; - anyification.currentModule = currentModule; + Anyification anyification{¤tModule->internalTypes, anyType, anyTypePack}; std::optional any = anyification.substitute(ty); if (any.has_value()) return *any; @@ -4823,7 +4769,8 @@ TypePackId TypeChecker::errorRecoveryTypePack(TypePackId guess) return getSingletonTypes().errorRecoveryTypePack(guess); } -TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) { +TypeIdPredicate TypeChecker::mkTruthyPredicate(bool sense) +{ return [this, sense](TypeId ty) -> std::optional { // any/error/free gets a special pass unconditionally because they can't be decided. if (get(ty) || get(ty) || get(ty)) @@ -4904,8 +4851,8 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (const auto& lit = annotation.as()) { std::optional tf; - if (lit->hasPrefix) - tf = scope->lookupImportedType(lit->prefix.value, lit->name.value); + if (lit->prefix) + tf = scope->lookupImportedType(lit->prefix->value, lit->name.value); else if (FFlag::DebugLuauMagicTypes && lit->name == "_luau_ice") ice("_luau_ice encountered", lit->location); @@ -4932,12 +4879,12 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (!tf) { - if (lit->name == Parser::errorName) + if (lit->name == kParseNameError) return errorRecoveryType(scope); std::string typeName; - if (lit->hasPrefix) - typeName = std::string(lit->prefix.value) + "."; + if (lit->prefix) + typeName = std::string(lit->prefix->value) + "."; typeName += lit->name.value; if (scope->lookupPack(typeName)) @@ -5038,12 +4985,7 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (notEnoughParameters && hasDefaultParameters) { // 'applyTypeFunction' is used to substitute default types that reference previous generic types - applyTypeFunction.log = TxnLog::empty(); - applyTypeFunction.typeArguments.clear(); - applyTypeFunction.typePackArguments.clear(); - applyTypeFunction.currentModule = currentModule; - applyTypeFunction.level = scope->level; - applyTypeFunction.encounteredForwardedType = false; + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; for (size_t i = 0; i < typesProvided; ++i) applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; @@ -5362,18 +5304,14 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, if (tf.typeParams.empty() && tf.typePackParams.empty()) return tf.type; - applyTypeFunction.typeArguments.clear(); + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; + for (size_t i = 0; i < tf.typeParams.size(); ++i) applyTypeFunction.typeArguments[tf.typeParams[i].ty] = typeParams[i]; - applyTypeFunction.typePackArguments.clear(); for (size_t i = 0; i < tf.typePackParams.size(); ++i) applyTypeFunction.typePackArguments[tf.typePackParams[i].tp] = typePackParams[i]; - applyTypeFunction.log = TxnLog::empty(); - applyTypeFunction.currentModule = currentModule; - applyTypeFunction.level = scope->level; - applyTypeFunction.encounteredForwardedType = false; std::optional maybeInstantiated = applyTypeFunction.substitute(tf.type); if (!maybeInstantiated.has_value()) { diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 8c6d5e49f..593b54c84 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -5,6 +5,8 @@ #include "Luau/ToString.h" #include "Luau/TypeInfer.h" +LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false) + namespace Luau { @@ -48,9 +50,19 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc } std::optional mtIndex = findMetatableEntry(errors, globalScope, ty, "__index", location); + int count = 0; while (mtIndex) { TypeId index = follow(*mtIndex); + + if (FFlag::LuauTerminateCyclicMetatableIndexLookup) + { + if (count >= 100) + return std::nullopt; + + ++count; + } + if (const auto& itt = getTableType(index)) { const auto& fit = itt->props.find(name); diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 7e438e31c..b2358c277 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,8 +23,6 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAG(LuauLengthOnCompositeType) -LUAU_FASTFLAGVARIABLE(LuauMetatableAreEqualRecursion, false) LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauUnionTagMatchFix) @@ -385,8 +383,6 @@ bool maybeSingleton(TypeId ty) bool hasLength(TypeId ty, DenseHashSet& seen, int* recursionCount) { - LUAU_ASSERT(FFlag::LuauLengthOnCompositeType); - RecursionLimiter _rl(recursionCount, FInt::LuauTypeInferRecursionLimit); ty = follow(ty); @@ -555,7 +551,7 @@ bool areEqual(SeenSet& seen, const TableTypeVar& lhs, const TableTypeVar& rhs) static bool areEqual(SeenSet& seen, const MetatableTypeVar& lhs, const MetatableTypeVar& rhs) { - if (FFlag::LuauMetatableAreEqualRecursion && areSeen(seen, &lhs, &rhs)) + if (areSeen(seen, &lhs, &rhs)) return true; return areEqual(seen, *lhs.table, *rhs.table) && areEqual(seen, *lhs.metatable, *rhs.metatable); diff --git a/Analysis/src/TypedAllocator.cpp b/Analysis/src/TypedAllocator.cpp index f037351e5..c7f31822d 100644 --- a/Analysis/src/TypedAllocator.cpp +++ b/Analysis/src/TypedAllocator.cpp @@ -7,6 +7,9 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include const size_t kPageSize = 4096; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index a8ad51593..322f6ebf5 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,7 +14,6 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTFLAGVARIABLE(LuauCommittingTxnLogFreeTpPromote, false) LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); @@ -23,7 +22,6 @@ LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAG(LuauProperTypeLevels); -LUAU_FASTFLAGVARIABLE(LuauUnifyPackTails, false) LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) @@ -116,7 +114,7 @@ struct PromoteTypeLevels { // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauCommittingTxnLogFreeTpPromote && FFlag::LuauUseCommittingTxnLog && !log.is(tp)) + if (FFlag::LuauUseCommittingTxnLog && !log.is(tp)) return true; promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); @@ -1242,7 +1240,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + if (subTpv->tail && superTpv->tail) { tryUnify_(*subTpv->tail, *superTpv->tail); break; @@ -1250,9 +1248,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) - tryUnify_(*subTpv->tail, *superTpv->tail); - else if (lFreeTail) + if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); else if (rFreeTail) tryUnify_(emptyTp, *subTpv->tail); @@ -1448,7 +1444,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal // If both are at the end, we're done if (!superIter.good() && !subIter.good()) { - if (FFlag::LuauUnifyPackTails && subTpv->tail && superTpv->tail) + if (subTpv->tail && superTpv->tail) { tryUnify_(*subTpv->tail, *superTpv->tail); break; @@ -1456,9 +1452,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (!FFlag::LuauUnifyPackTails && lFreeTail && rFreeTail) - tryUnify_(*subTpv->tail, *superTpv->tail); - else if (lFreeTail) + if (lFreeTail) tryUnify_(emptyTp, *superTpv->tail); else if (rFreeTail) tryUnify_(emptyTp, *subTpv->tail); diff --git a/Ast/include/Luau/Ast.h b/Ast/include/Luau/Ast.h index ac5950c0a..31cd01ccd 100644 --- a/Ast/include/Luau/Ast.h +++ b/Ast/include/Luau/Ast.h @@ -594,8 +594,7 @@ class AstExprFunction : public AstExpr AstArray genericPacks; AstLocal* self; AstArray args; - bool hasReturnAnnotation; - AstTypeList returnAnnotation; + std::optional returnAnnotation; bool vararg = false; Location varargLocation; AstTypePack* varargAnnotation; @@ -740,7 +739,7 @@ class AstStatIf : public AstStat public: LUAU_RTTI(AstStatIf) - AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, const Location& thenLocation, + AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd); void visit(AstVisitor* visitor) override; @@ -749,12 +748,10 @@ class AstStatIf : public AstStat AstStatBlock* thenbody; AstStat* elsebody; - bool hasThen = false; - Location thenLocation; + std::optional thenLocation; // Active for 'elseif' as well - bool hasElse = false; - Location elseLocation; + std::optional elseLocation; bool hasEnd = false; }; @@ -849,8 +846,7 @@ class AstStatLocal : public AstStat AstArray vars; AstArray values; - bool hasEqualsSign = false; - Location equalsSignLocation; + std::optional equalsSignLocation; }; class AstStatFor : public AstStat @@ -1053,9 +1049,8 @@ class AstTypeReference : public AstType void visit(AstVisitor* visitor) override; - bool hasPrefix; bool hasParameterList; - AstName prefix; + std::optional prefix; AstName name; AstArray parameters; }; diff --git a/Ast/include/Luau/Lexer.h b/Ast/include/Luau/Lexer.h index 460ef0565..d7d867f48 100644 --- a/Ast/include/Luau/Lexer.h +++ b/Ast/include/Luau/Lexer.h @@ -233,4 +233,9 @@ class Lexer bool readNames; }; +inline bool isSpace(char ch) +{ + return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; +} + } // namespace Luau diff --git a/Ast/include/Luau/ParseResult.h b/Ast/include/Luau/ParseResult.h new file mode 100644 index 000000000..17ce2e3bb --- /dev/null +++ b/Ast/include/Luau/ParseResult.h @@ -0,0 +1,69 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +#pragma once + +#include "Luau/Common.h" +#include "Luau/Location.h" +#include "Luau/Lexer.h" +#include "Luau/StringUtils.h" + +namespace Luau +{ + +class AstStatBlock; + +class ParseError : public std::exception +{ +public: + ParseError(const Location& location, const std::string& message); + + virtual const char* what() const throw(); + + const Location& getLocation() const; + const std::string& getMessage() const; + + static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); + +private: + Location location; + std::string message; +}; + +class ParseErrors : public std::exception +{ +public: + ParseErrors(std::vector errors); + + virtual const char* what() const throw(); + + const std::vector& getErrors() const; + +private: + std::vector errors; + std::string message; +}; + +struct HotComment +{ + bool header; + Location location; + std::string content; +}; + +struct Comment +{ + Lexeme::Type type; // Comment, BlockComment, or BrokenComment + Location location; +}; + +struct ParseResult +{ + AstStatBlock* root; + std::vector hotcomments; + std::vector errors; + + std::vector commentLocations; +}; + +static constexpr const char* kParseNameError = "%error-id%"; + +} // namespace Luau diff --git a/Ast/include/Luau/Parser.h b/Ast/include/Luau/Parser.h index 40ecdcdd5..4b5ae3150 100644 --- a/Ast/include/Luau/Parser.h +++ b/Ast/include/Luau/Parser.h @@ -4,6 +4,7 @@ #include "Luau/Ast.h" #include "Luau/Lexer.h" #include "Luau/ParseOptions.h" +#include "Luau/ParseResult.h" #include "Luau/StringUtils.h" #include "Luau/DenseHash.h" #include "Luau/Common.h" @@ -14,37 +15,6 @@ namespace Luau { -class ParseError : public std::exception -{ -public: - ParseError(const Location& location, const std::string& message); - - virtual const char* what() const throw(); - - const Location& getLocation() const; - const std::string& getMessage() const; - - static LUAU_NORETURN void raise(const Location& location, const char* format, ...) LUAU_PRINTF_ATTR(2, 3); - -private: - Location location; - std::string message; -}; - -class ParseErrors : public std::exception -{ -public: - ParseErrors(std::vector errors); - - virtual const char* what() const throw(); - - const std::vector& getErrors() const; - -private: - std::vector errors; - std::string message; -}; - template class TempVector { @@ -80,34 +50,17 @@ class TempVector size_t size_; }; -struct Comment -{ - Lexeme::Type type; // Comment, BlockComment, or BrokenComment - Location location; -}; - -struct ParseResult -{ - AstStatBlock* root; - std::vector hotcomments; - std::vector errors; - - std::vector commentLocations; -}; - class Parser { public: static ParseResult parse( const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, ParseOptions options = ParseOptions()); - static constexpr const char* errorName = "%error-id%"; - private: struct Name; struct Binding; - Parser(const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator); + Parser(const char* buffer, std::size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options); bool blockFollow(const Lexeme& l); @@ -330,7 +283,7 @@ class Parser AstTypeError* reportTypeAnnotationError(const Location& location, const AstArray& types, bool isMissing, const char* format, ...) LUAU_PRINTF_ATTR(5, 6); - const Lexeme& nextLexeme(); + void nextLexeme(); struct Function { @@ -386,6 +339,9 @@ class Parser Allocator& allocator; std::vector commentLocations; + std::vector hotcomments; + + bool hotcommentHeader = true; unsigned int recursionCounter; diff --git a/Ast/src/Ast.cpp b/Ast/src/Ast.cpp index 9b5bc0c71..24a280da1 100644 --- a/Ast/src/Ast.cpp +++ b/Ast/src/Ast.cpp @@ -167,8 +167,7 @@ AstExprFunction::AstExprFunction(const Location& location, const AstArrayreturnAnnotation = *returnAnnotation; } void AstExprFunction::visit(AstVisitor* visitor) @@ -195,8 +192,8 @@ void AstExprFunction::visit(AstVisitor* visitor) if (varargAnnotation) varargAnnotation->visit(visitor); - if (hasReturnAnnotation) - visitTypeList(visitor, returnAnnotation); + if (returnAnnotation) + visitTypeList(visitor, *returnAnnotation); body->visit(visitor); } @@ -375,21 +372,16 @@ void AstStatBlock::visit(AstVisitor* visitor) } } -AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, bool hasThen, - const Location& thenLocation, const std::optional& elseLocation, bool hasEnd) +AstStatIf::AstStatIf(const Location& location, AstExpr* condition, AstStatBlock* thenbody, AstStat* elsebody, + const std::optional& thenLocation, const std::optional& elseLocation, bool hasEnd) : AstStat(ClassIndex(), location) , condition(condition) , thenbody(thenbody) , elsebody(elsebody) - , hasThen(hasThen) , thenLocation(thenLocation) + , elseLocation(elseLocation) , hasEnd(hasEnd) { - if (bool(elseLocation)) - { - hasElse = true; - this->elseLocation = *elseLocation; - } } void AstStatIf::visit(AstVisitor* visitor) @@ -492,12 +484,8 @@ AstStatLocal::AstStatLocal( : AstStat(ClassIndex(), location) , vars(vars) , values(values) + , equalsSignLocation(equalsSignLocation) { - if (bool(equalsSignLocation)) - { - hasEqualsSign = true; - this->equalsSignLocation = *equalsSignLocation; - } } void AstStatLocal::visit(AstVisitor* visitor) @@ -750,9 +738,8 @@ void AstStatError::visit(AstVisitor* visitor) AstTypeReference::AstTypeReference( const Location& location, std::optional prefix, AstName name, bool hasParameterList, const AstArray& parameters) : AstType(ClassIndex(), location) - , hasPrefix(bool(prefix)) , hasParameterList(hasParameterList) - , prefix(prefix ? *prefix : AstName()) + , prefix(prefix) , name(name) , parameters(parameters) { diff --git a/Ast/src/Lexer.cpp b/Ast/src/Lexer.cpp index a7aa24ca9..d56c88608 100644 --- a/Ast/src/Lexer.cpp +++ b/Ast/src/Lexer.cpp @@ -101,11 +101,6 @@ Lexeme::Lexeme(const Location& location, Type type, const char* name) LUAU_ASSERT(type == Name || (type >= Reserved_BEGIN && type < Lexeme::Reserved_END)); } -static bool isComment(const Lexeme& lexeme) -{ - return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; -} - static const char* kReserved[] = {"and", "break", "do", "else", "elseif", "end", "false", "for", "function", "if", "in", "local", "nil", "not", "or", "repeat", "return", "then", "true", "until", "while"}; @@ -282,11 +277,6 @@ AstName AstNameTable::get(const char* name) const return getWithType(name, strlen(name)).first; } -inline bool isSpace(char ch) -{ - return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; -} - inline bool isAlpha(char ch) { // use or trick to convert to lower case and unsigned comparison to do range check @@ -372,7 +362,7 @@ const Lexeme& Lexer::next(bool skipComments) prevLocation = lexeme.location; lexeme = readNext(); - } while (skipComments && isComment(lexeme)); + } while (skipComments && (lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment)); return lexeme; } diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 30b32f914..235d6349d 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -13,18 +13,15 @@ LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) -LUAU_FASTFLAGVARIABLE(LuauStartingBrokenComment, false) +LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) +LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) namespace Luau { -inline bool isSpace(char ch) -{ - return ch == ' ' || ch == '\t' || ch == '\r' || ch == '\n' || ch == '\v' || ch == '\f'; -} - static bool isComment(const Lexeme& lexeme) { + LUAU_ASSERT(!FFlag::LuauParseAllHotComments); return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; } @@ -151,31 +148,37 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); - Parser p(buffer, bufferSize, names, allocator); + Parser p(buffer, bufferSize, names, allocator, FFlag::LuauParseAllHotComments ? options : ParseOptions()); try { - std::vector hotcomments; + if (FFlag::LuauParseAllHotComments) + { + AstStatBlock* root = p.parseChunk(); - while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) + return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; + } + else { - const char* text = p.lexer.current().data; - unsigned int length = p.lexer.current().length; + std::vector hotcomments; - if (length && text[0] == '!') + while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) { - unsigned int end = length; - while (end > 0 && isSpace(text[end - 1])) - --end; + const char* text = p.lexer.current().data; + unsigned int length = p.lexer.current().length; - hotcomments.push_back(std::string(text + 1, text + end)); - } + if (length && text[0] == '!') + { + unsigned int end = length; + while (end > 0 && isSpace(text[end - 1])) + --end; - const Lexeme::Type type = p.lexer.current().type; - const Location loc = p.lexer.current().location; + hotcomments.push_back({true, p.lexer.current().location, std::string(text + 1, text + end)}); + } + + const Lexeme::Type type = p.lexer.current().type; + const Location loc = p.lexer.current().location; - if (FFlag::LuauStartingBrokenComment) - { if (options.captureComments) p.commentLocations.push_back(Comment{type, loc}); @@ -184,22 +187,15 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n p.lexer.next(); } - else - { - p.lexer.next(); - if (options.captureComments) - p.commentLocations.push_back(Comment{type, loc}); - } - } - - p.lexer.setSkipComments(true); + p.lexer.setSkipComments(true); - p.options = options; + p.options = options; - AstStatBlock* root = p.parseChunk(); + AstStatBlock* root = p.parseChunk(); - return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; + return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; + } } catch (ParseError& err) { @@ -210,8 +206,9 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n } } -Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator) - : lexer(buffer, bufferSize, names) +Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Allocator& allocator, const ParseOptions& options) + : options(options) + , lexer(buffer, bufferSize, names) , allocator(allocator) , recursionCounter(0) , endMismatchSuspect(Location(), Lexeme::Eof) @@ -224,14 +221,20 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc nameSelf = names.addStatic("self"); nameNumber = names.addStatic("number"); - nameError = names.addStatic(errorName); + nameError = names.addStatic(kParseNameError); nameNil = names.getOrAdd("nil"); // nil is a reserved keyword matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); matchRecoveryStopOnToken[Lexeme::Type::Eof] = 1; + if (FFlag::LuauParseAllHotComments) + lexer.setSkipComments(true); + // read first lexeme nextLexeme(); + + // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode + hotcommentHeader = false; } bool Parser::blockFollow(const Lexeme& l) @@ -396,7 +399,9 @@ AstStat* Parser::parseIf() AstExpr* cond = parseExpr(); Lexeme matchThen = lexer.current(); - bool hasThen = expectAndConsume(Lexeme::ReservedThen, "if statement"); + std::optional thenLocation; + if (expectAndConsume(Lexeme::ReservedThen, "if statement")) + thenLocation = matchThen.location; AstStatBlock* thenbody = parseBlock(); @@ -434,7 +439,7 @@ AstStat* Parser::parseIf() hasEnd = expectMatchEndAndConsume(Lexeme::ReservedEnd, matchThenElse); } - return allocator.alloc(Location(start, end), cond, thenbody, elsebody, hasThen, matchThen.location, elseLocation, hasEnd); + return allocator.alloc(Location(start, end), cond, thenbody, elsebody, thenLocation, elseLocation, hasEnd); } // while exp do block end @@ -769,7 +774,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) { // note: `type` token is already parsed for us, so we just need to parse the rest - auto name = parseNameOpt("type name"); + std::optional name = parseNameOpt("type name"); // Use error name if the name is missing if (!name) @@ -925,7 +930,7 @@ AstStat* Parser::parseDeclaration(const Location& start) return allocator.alloc(Location(classStart, classEnd), className.name, superName, copy(props)); } - else if (auto globalName = parseNameOpt("global variable name")) + else if (std::optional globalName = parseNameOpt("global variable name")) { expectAndConsume(':', "global variable declaration"); @@ -1066,7 +1071,7 @@ void Parser::parseExprList(TempVector& result) Parser::Binding Parser::parseBinding() { - auto name = parseNameOpt("variable name"); + std::optional name = parseNameOpt("variable name"); // Use placeholder if the name is missing if (!name) @@ -1325,7 +1330,7 @@ AstType* Parser::parseTableTypeAnnotation() } else { - auto name = parseNameOpt("table field"); + std::optional name = parseNameOpt("table field"); if (!name) break; @@ -1422,7 +1427,7 @@ AstType* Parser::parseFunctionTypeAnnotationTail(const Lexeme& begin, AstArray(Location(begin.location, endLocation), generics, genericPacks, paramTypes, paramNames, returnTypeList); @@ -1869,7 +1874,7 @@ AstExpr* Parser::parseExpr(unsigned int limit) // NAME AstExpr* Parser::parseNameExpr(const char* context) { - auto name = parseNameOpt(context); + std::optional name = parseNameOpt(context); if (!name) return allocator.alloc(lexer.current().location, copy({}), unsigned(parseErrors.size() - 1)); @@ -2233,6 +2238,12 @@ AstExpr* Parser::parseTableConstructor() AstExpr* key = allocator.alloc(name.location, nameString); AstExpr* value = parseExpr(); + if (FFlag::LuauTableFieldFunctionDebugname) + { + if (AstExprFunction* func = value->as()) + func->debugname = name.name; + } + items.push_back({AstExprTable::Item::Record, key, value}); } else @@ -2313,7 +2324,7 @@ std::optional Parser::parseNameOpt(const char* context) Parser::Name Parser::parseName(const char* context) { - if (auto name = parseNameOpt(context)) + if (std::optional name = parseNameOpt(context)) return *name; Location location = lexer.current().location; @@ -2324,7 +2335,7 @@ Parser::Name Parser::parseName(const char* context) Parser::Name Parser::parseIndexName(const char* context, const Position& previous) { - if (auto name = parseNameOpt(context)) + if (std::optional name = parseNameOpt(context)) return *name; // If we have a reserved keyword next at the same line, assume it's an incomplete name @@ -2379,7 +2390,7 @@ std::pair, AstArray> Parser::parseG if (shouldParseTypePackAnnotation(lexer)) { - auto typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePackAnnotation(); namePacks.push_back({name, nameLocation, typePack}); } @@ -2451,7 +2462,7 @@ AstArray Parser::parseTypeParams() { if (shouldParseTypePackAnnotation(lexer)) { - auto typePack = parseTypePackAnnotation(); + AstTypePack* typePack = parseTypePackAnnotation(); parameters.push_back({{}, typePack}); } @@ -2821,25 +2832,57 @@ AstTypeError* Parser::reportTypeAnnotationError(const Location& location, const return allocator.alloc(location, types, isMissing, unsigned(parseErrors.size() - 1)); } -const Lexeme& Parser::nextLexeme() +void Parser::nextLexeme() { if (options.captureComments) { - while (true) + if (FFlag::LuauParseAllHotComments) { - const Lexeme& lexeme = lexer.next(/*skipComments*/ false); - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (lexeme.type == Lexeme::BrokenComment) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - if (isComment(lexeme)) + Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + + while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) + { + const Lexeme& lexeme = lexer.current(); commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - else - return lexeme; + + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + return; + + // Comments starting with ! are called "hot comments" and contain directives for type checking / linting + if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + { + const char* text = lexeme.data; + + unsigned int end = lexeme.length; + while (end > 0 && isSpace(text[end - 1])) + --end; + + hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); + } + + type = lexer.next(/* skipComments= */ false).type; + } + } + else + { + while (true) + { + const Lexeme& lexeme = lexer.next(/*skipComments*/ false); + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + if (isComment(lexeme)) + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); + else + return; + } } } else - return lexer.next(); + lexer.next(); } } // namespace Luau diff --git a/Ast/src/TimeTrace.cpp b/Ast/src/TimeTrace.cpp index ded50e53e..8079830b3 100644 --- a/Ast/src/TimeTrace.cpp +++ b/Ast/src/TimeTrace.cpp @@ -9,6 +9,12 @@ #include #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include #endif diff --git a/CLI/FileUtils.cpp b/CLI/FileUtils.cpp index fe005aece..fb6ac3734 100644 --- a/CLI/FileUtils.cpp +++ b/CLI/FileUtils.cpp @@ -4,8 +4,13 @@ #include "Luau/Common.h" #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN -#include +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include #else #include #include diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 9a6e25c28..13304d57f 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -37,6 +37,8 @@ enum class CompileFormat Binary }; +constexpr int MaxTraversalLimit = 50; + struct GlobalOptions { int optimizationLevel = 1; @@ -243,72 +245,143 @@ std::string runCode(lua_State* L, const std::string& source) return std::string(); } -static void completeIndexer(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) +// Replaces the top of the lua stack with the metatable __index for the value +// if it exists. Returns true iff __index exists. +static bool tryReplaceTopWithIndex(lua_State* L) { - std::string_view lookup = editBuffer; - char lastSep = 0; - - for (;;) + if (luaL_getmetafield(L, -1, "__index")) { - size_t sep = lookup.find_first_of(".:"); - std::string_view prefix = lookup.substr(0, sep); + // Remove the table leaving __index on the top of stack + lua_remove(L, -2); + return true; + } + return false; +} - if (sep == std::string_view::npos) + +// This function is similar to lua_gettable, but it avoids calling any +// lua callback functions (e.g. __index) which might modify the Lua VM state. +static void safeGetTable(lua_State* L, int tableIndex) +{ + lua_pushvalue(L, tableIndex); // Duplicate the table + + // The loop invariant is that the table to search is at -1 + // and the key is at -2. + for (int loopCount = 0;; loopCount++) + { + lua_pushvalue(L, -2); // Duplicate the key + lua_rawget(L, -2); // Try to find the key + if (!lua_isnil(L, -1) || loopCount >= MaxTraversalLimit) { - // table, key - lua_pushnil(L); + // Either the key has been found, and/or we have reached the max traversal limit + break; + } + else + { + lua_pop(L, 1); // Pop the nil result + if (!luaL_getmetafield(L, -1, "__index")) + { + lua_pushnil(L); + break; + } + else if (lua_istable(L, -1)) + { + // Replace the current table being searched with __index table + lua_replace(L, -2); + } + else + { + lua_pop(L, 1); // Pop the value + lua_pushnil(L); + break; + } + } + } - while (lua_next(L, -2) != 0) + lua_remove(L, -2); // Remove the table + lua_remove(L, -2); // Remove the original key +} + +// completePartialMatches finds keys that match the specified 'prefix' +// Note: the table/object to be searched must be on the top of the Lua stack +static void completePartialMatches(lua_State* L, bool completeOnlyFunctions, const std::string& editBuffer, std::string_view prefix, + const AddCompletionCallback& addCompletionCallback) +{ + for (int i = 0; i < MaxTraversalLimit && lua_istable(L, -1); i++) + { + // table, key + lua_pushnil(L); + + // Loop over all the keys in the current table + while (lua_next(L, -2) != 0) + { + if (lua_type(L, -2) == LUA_TSTRING) { - if (lua_type(L, -2) == LUA_TSTRING) - { - // table, key, value - std::string_view key = lua_tostring(L, -2); - int valueType = lua_type(L, -1); + // table, key, value + std::string_view key = lua_tostring(L, -2); + int valueType = lua_type(L, -1); - // If the last separator was a ':' (i.e. a method call) then only functions should be completed. - bool requiredValueType = (lastSep != ':' || valueType == LUA_TFUNCTION); + // If the last separator was a ':' (i.e. a method call) then only functions should be completed. + bool requiredValueType = (!completeOnlyFunctions || valueType == LUA_TFUNCTION); - if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) + if (!key.empty() && requiredValueType && Luau::startsWith(key, prefix)) + { + std::string completedComponent(key.substr(prefix.size())); + std::string completion(editBuffer + completedComponent); + if (valueType == LUA_TFUNCTION) { - std::string completedComponent(key.substr(prefix.size())); - std::string completion(editBuffer + completedComponent); - if (valueType == LUA_TFUNCTION) - { - // Add an opening paren for function calls by default. - completion += "("; - } - addCompletionCallback(completion, std::string(key)); + // Add an opening paren for function calls by default. + completion += "("; } + addCompletionCallback(completion, std::string(key)); } - lua_pop(L, 1); } + lua_pop(L, 1); + } + + // Replace the current table being searched with an __index table if one exists + if (!tryReplaceTopWithIndex(L)) + { + break; + } + } +} + +static void completeIndexer(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) +{ + std::string_view lookup = editBuffer; + bool completeOnlyFunctions = false; + + // Push the global variable table to begin the search + lua_pushvalue(L, LUA_GLOBALSINDEX); + + for (;;) + { + size_t sep = lookup.find_first_of(".:"); + std::string_view prefix = lookup.substr(0, sep); + if (sep == std::string_view::npos) + { + completePartialMatches(L, completeOnlyFunctions, editBuffer, prefix, addCompletionCallback); break; } else { // find the key in the table lua_pushlstring(L, prefix.data(), prefix.size()); - lua_rawget(L, -2); + safeGetTable(L, -2); lua_remove(L, -2); - if (lua_type(L, -1) == LUA_TSTRING) + if (lua_istable(L, -1) || tryReplaceTopWithIndex(L)) { - // Replace the string object with the string class to perform further lookups of string functions - // Note: We retrieve the string class from _G to prevent issues if the user assigns to `string`. - lua_pop(L, 1); // Pop the string instance - lua_getglobal(L, "_G"); - lua_pushlstring(L, "string", 6); - lua_rawget(L, -2); - lua_remove(L, -2); // Remove the global table - LUAU_ASSERT(lua_istable(L, -1)); + completeOnlyFunctions = lookup[sep] == ':'; + lookup.remove_prefix(sep + 1); } - else if (!lua_istable(L, -1)) + else + { + // Unable to search for keys, so stop searching break; - - lastSep = lookup[sep]; - lookup.remove_prefix(sep + 1); + } } } @@ -317,12 +390,6 @@ static void completeIndexer(lua_State* L, const std::string& editBuffer, const A void getCompletions(lua_State* L, const std::string& editBuffer, const AddCompletionCallback& addCompletionCallback) { - // look the value up in current global table first - lua_pushvalue(L, LUA_GLOBALSINDEX); - completeIndexer(L, editBuffer, addCompletionCallback); - - // and in actual global table after that - lua_getglobal(L, "_G"); completeIndexer(L, editBuffer, addCompletionCallback); } @@ -365,9 +432,7 @@ struct LinenoiseScopedHistory ic_set_history(historyFilepath.c_str(), -1 /* default entries (= 200) */); } - ~LinenoiseScopedHistory() - { - } + ~LinenoiseScopedHistory() {} std::string historyFilepath; }; diff --git a/CMakeLists.txt b/CMakeLists.txt index c19d2b40b..c6ccebc54 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -104,6 +104,20 @@ if (MSVC AND MSVC_VERSION GREATER_EQUAL 1924) set_source_files_properties(VM/src/lvmexecute.cpp PROPERTIES COMPILE_FLAGS /d2ssa-pre-) endif() +# embed .natvis inside the library debug information +if(MSVC) + target_link_options(Luau.Ast INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Ast.natvis) + target_link_options(Luau.Analysis INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/Analysis.natvis) + target_link_options(Luau.VM INTERFACE /NATVIS:${CMAKE_CURRENT_SOURCE_DIR}/tools/natvis/VM.natvis) +endif() + +# make .natvis visible inside the solution +if(MSVC_IDE) + target_sources(Luau.Ast PRIVATE tools/natvis/Ast.natvis) + target_sources(Luau.Analysis PRIVATE tools/natvis/Analysis.natvis) + target_sources(Luau.VM PRIVATE tools/natvis/VM.natvis) +endif() + if(LUAU_BUILD_CLI) target_compile_options(Luau.Repl.CLI PRIVATE ${LUAU_OPTIONS}) target_compile_options(Luau.Analyze.CLI PRIVATE ${LUAU_OPTIONS}) diff --git a/Sources.cmake b/Sources.cmake index 773f6f351..615641eb1 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -8,6 +8,7 @@ target_sources(Luau.Ast PRIVATE Ast/include/Luau/Location.h Ast/include/Luau/ParseOptions.h Ast/include/Luau/Parser.h + Ast/include/Luau/ParseResult.h Ast/include/Luau/StringUtils.h Ast/include/Luau/TimeTrace.h diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 5cffba63c..29d5f397e 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -36,12 +36,9 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" static Table* getcurrenv(lua_State* L) { if (L->ci == L->base_ci) /* no enclosing function? */ - return hvalue(gt(L)); /* use global table as environment */ + return L->gt; /* use global table as environment */ else - { - Closure* func = curr_func(L); - return func->env; - } + return curr_func(L)->env; } static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) @@ -53,11 +50,14 @@ static LUAU_NOINLINE TValue* pseudo2addr(lua_State* L, int idx) return registry(L); case LUA_ENVIRONINDEX: { - sethvalue(L, &L->env, getcurrenv(L)); - return &L->env; + sethvalue(L, &L->global->pseudotemp, getcurrenv(L)); + return &L->global->pseudotemp; } case LUA_GLOBALSINDEX: - return gt(L); + { + sethvalue(L, &L->global->pseudotemp, L->gt); + return &L->global->pseudotemp; + } default: { Closure* func = curr_func(L); @@ -237,6 +237,11 @@ void lua_replace(lua_State* L, int idx) func->env = hvalue(L->top - 1); luaC_barrier(L, func, L->top - 1); } + else if (idx == LUA_GLOBALSINDEX) + { + api_check(L, ttistable(L->top - 1)); + L->gt = hvalue(L->top - 1); + } else { setobj(L, o, L->top - 1); @@ -783,7 +788,7 @@ void lua_getfenv(lua_State* L, int idx) sethvalue(L, L->top, clvalue(o)->env); break; case LUA_TTHREAD: - setobj2s(L, L->top, gt(thvalue(o))); + sethvalue(L, L->top, thvalue(o)->gt); break; default: setnilvalue(L->top); @@ -914,7 +919,7 @@ int lua_setfenv(lua_State* L, int idx) clvalue(o)->env = hvalue(L->top - 1); break; case LUA_TTHREAD: - sethvalue(L, gt(thvalue(o)), hvalue(L->top - 1)); + thvalue(o)->gt = hvalue(L->top - 1); break; default: res = 0; diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index e9930f7ab..a4f93c621 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -419,7 +419,7 @@ static void getcoverage(Proto* p, int depth, int* buffer, size_t size, void* con } const char* debugname = p->debugname ? getstr(p->debugname) : NULL; - int linedefined = luaG_getline(p, 0); + int linedefined = getlinedefined(p); callback(context, debugname, linedefined, depth, buffer, size); diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index a3982bc68..d87f06618 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,6 +17,8 @@ #include +LUAU_FASTFLAG(LuauReduceStackReallocs) + /* ** {====================================================== ** Error-recovery functions @@ -164,13 +166,14 @@ static void correctstack(lua_State* L, TValue* oldstack) void luaD_reallocstack(lua_State* L, int newsize) { TValue* oldstack = L->stack; - int realsize = newsize + 1 + EXTRA_STACK; - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK - 1); + int realsize = newsize + (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK); + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); luaM_reallocarray(L, L->stack, L->stacksize, realsize, TValue, L->memcat); + TValue* newstack = L->stack; for (int i = L->stacksize; i < realsize; i++) - setnilvalue(L->stack + i); /* erase new segment */ + setnilvalue(newstack + i); /* erase new segment */ L->stacksize = realsize; - L->stack_last = L->stack + newsize; + L->stack_last = newstack + newsize; correctstack(L, oldstack); } @@ -512,7 +515,7 @@ static void callerrfunc(lua_State* L, void* ud) static void restore_stack_limit(lua_State* L) { - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK - 1); + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); if (L->size_ci > LUAI_MAXCALLS) { /* there was an overflow? */ int inuse = cast_int(L->ci - L->base_ci); diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 72807f0f3..1c1480d68 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -11,7 +11,7 @@ if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ luaD_growstack(L, n); \ else \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK - 1)); + condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); #define incr_top(L) \ { \ diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 835572fa7..724b24b2a 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -268,7 +268,7 @@ static void traverseclosure(global_State* g, Closure* cl) static void traversestack(global_State* g, lua_State* l, bool clearstack) { - markvalue(g, gt(l)); + markobject(g, l->gt); if (l->namecall) stringmark(l->namecall); for (StkId o = l->stack; o < l->top; o++) @@ -643,7 +643,7 @@ static void markroot(lua_State* L) g->weak = NULL; markobject(g, g->mainthread); /* make global table be traversed before main stack */ - markvalue(g, gt(g->mainthread)); + markobject(g, g->mainthread->gt); markvalue(g, registry(L)); markmt(g); g->gcstate = GCSpropagate; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 528d09446..2acb5d8aa 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -77,7 +77,7 @@ #define luaC_checkGC(L) \ { \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK - 1)); \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); \ if (L->global->totalbytes >= L->global->GCthreshold) \ { \ condhardmemtests(luaC_validate(L), 1); \ diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index ce1965200..30242e526 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -88,7 +88,7 @@ static void validateclosure(global_State* g, Closure* cl) static void validatestack(global_State* g, lua_State* l) { - validateref(g, obj2gco(l), gt(l)); + validateobjref(g, obj2gco(l), obj2gco(l->gt)); for (CallInfo* ci = l->base_ci; ci <= l->ci; ++ci) { @@ -370,6 +370,7 @@ static void dumpclosure(FILE* f, Closure* cl) fprintf(f, ",\"env\":"); dumpref(f, obj2gco(cl->env)); + if (cl->isC) { if (cl->nupvalues) @@ -411,11 +412,8 @@ static void dumpthread(FILE* f, lua_State* th) fprintf(f, "{\"type\":\"thread\",\"cat\":%d,\"size\":%d", th->memcat, int(size)); - if (iscollectable(&th->l_gt)) - { - fprintf(f, ",\"env\":"); - dumpref(f, gcvalue(&th->l_gt)); - } + fprintf(f, ",\"env\":"); + dumpref(f, obj2gco(th->gt)); Closure* tcl = 0; for (CallInfo* ci = th->base_ci; ci <= th->ci; ++ci) diff --git a/VM/src/lperf.cpp b/VM/src/lperf.cpp index 2f6c72972..da68e3766 100644 --- a/VM/src/lperf.cpp +++ b/VM/src/lperf.cpp @@ -3,6 +3,12 @@ #include "lua.h" #ifdef _WIN32 +#ifndef WIN32_LEAN_AND_MEAN +#define WIN32_LEAN_AND_MEAN +#endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include #endif diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index 6762c6380..d6d127c02 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -11,6 +11,7 @@ #include "ldebug.h" LUAU_FASTFLAG(LuauGcPagedSweep) +LUAU_FASTFLAGVARIABLE(LuauReduceStackReallocs, false) /* ** Main thread combines a thread state and the global state @@ -31,10 +32,11 @@ static void stack_init(lua_State* L1, lua_State* L) /* initialize stack array */ L1->stack = luaM_newarray(L, BASIC_STACK_SIZE + EXTRA_STACK, TValue, L1->memcat); L1->stacksize = BASIC_STACK_SIZE + EXTRA_STACK; + TValue* stack = L1->stack; for (int i = 0; i < BASIC_STACK_SIZE + EXTRA_STACK; i++) - setnilvalue(L1->stack + i); /* erase new stack */ - L1->top = L1->stack; - L1->stack_last = L1->stack + (L1->stacksize - EXTRA_STACK) - 1; + setnilvalue(stack + i); /* erase new stack */ + L1->top = stack; + L1->stack_last = stack + (L1->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); /* initialize first ci */ L1->ci->func = L1->top; setnilvalue(L1->top++); /* `function' entry for this `ci' */ @@ -55,7 +57,7 @@ static void f_luaopen(lua_State* L, void* ud) { global_State* g = L->global; stack_init(L, L); /* init stack */ - sethvalue(L, gt(L), luaH_new(L, 0, 2)); /* table of globals */ + L->gt = luaH_new(L, 0, 2); /* table of globals */ sethvalue(L, registry(L), luaH_new(L, 0, 2)); /* registry */ luaS_resize(L, LUA_MINSTRTABSIZE); /* initial size of string table */ luaT_init(L); @@ -69,6 +71,7 @@ static void preinit_state(lua_State* L, global_State* g) L->global = g; L->stack = NULL; L->stacksize = 0; + L->gt = NULL; L->openupval = NULL; L->size_ci = 0; L->nCcalls = L->baseCcalls = 0; @@ -80,7 +83,6 @@ static void preinit_state(lua_State* L, global_State* g) L->stackstate = 0; L->activememcat = 0; L->userdata = NULL; - setnilvalue(gt(L)); } static void close_state(lua_State* L) @@ -116,7 +118,7 @@ lua_State* luaE_newthread(lua_State* L) preinit_state(L1, L->global); L1->activememcat = L->activememcat; // inherit the active memory category stack_init(L1, L); /* init stack */ - setobj2n(L, gt(L1), gt(L)); /* share table of globals */ + L1->gt = L->gt; /* share table of globals */ L1->singlestep = L->singlestep; LUAU_ASSERT(iswhite(obj2gco(L1))); return L1; @@ -144,14 +146,30 @@ void lua_resetthread(lua_State* L) ci->top = ci->base + LUA_MINSTACK; setnilvalue(ci->func); L->ci = ci; - luaD_reallocCI(L, BASIC_CI_SIZE); + if (FFlag::LuauReduceStackReallocs) + { + if (L->size_ci != BASIC_CI_SIZE) + luaD_reallocCI(L, BASIC_CI_SIZE); + } + else + { + luaD_reallocCI(L, BASIC_CI_SIZE); + } /* clear thread state */ L->status = LUA_OK; L->base = L->ci->base; L->top = L->ci->base; L->nCcalls = L->baseCcalls = 0; /* clear thread stack */ - luaD_reallocstack(L, BASIC_STACK_SIZE); + if (FFlag::LuauReduceStackReallocs) + { + if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) + luaD_reallocstack(L, BASIC_STACK_SIZE); + } + else + { + luaD_reallocstack(L, BASIC_STACK_SIZE); + } for (int i = 0; i < L->stacksize; i++) setnilvalue(L->stack + i); } @@ -193,6 +211,7 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->strt.size = 0; g->strt.nuse = 0; g->strt.hash = NULL; + setnilvalue(&g->pseudotemp); setnilvalue(registry(L)); g->gcstate = GCSpause; if (!FFlag::LuauGcPagedSweep) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 0708b71f3..6dd891382 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -5,9 +5,6 @@ #include "lobject.h" #include "ltm.h" -/* table of globals */ -#define gt(L) (&L->l_gt) - /* registry */ #define registry(L) (&L->global->registry) @@ -177,6 +174,8 @@ typedef struct global_State TString* ttname[LUA_T_COUNT]; /* names for basic types */ TString* tmname[TM_N]; /* array with tag-method names */ + TValue pseudotemp; /* storage for temporary values used in pseudo2addr */ + TValue registry; /* registry table, used by lua_ref and LUA_REGISTRYINDEX */ int registryfree; /* next free slot in registry */ @@ -231,8 +230,7 @@ struct lua_State int cachedslot; /* when table operations or INDEX/NEWINDEX is invoked from Luau, what is the expected slot for lookup? */ - TValue l_gt; /* table of globals */ - TValue env; /* temporary place for environments */ + Table* gt; /* table of globals */ UpVal* openupval; /* list of open upvalues in this stack */ GCObject* gclist; diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index c3b662a2c..6c31d36f2 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -39,7 +39,7 @@ // When calling luau_callTM, we usually push the arguments to the top of the stack. // This is safe to do for complicated reasons: -// - stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) +// - stack guarantees EXTRA_STACK room beyond stack_last (see luaD_reallocstack) // - stack reallocation copies values past stack_last // All external function calls that can cause stack realloc or Lua calls have to be wrapped in VM_PROTECT diff --git a/VM/src/lvmload.cpp b/VM/src/lvmload.cpp index 2472cd902..4e5435b7f 100644 --- a/VM/src/lvmload.cpp +++ b/VM/src/lvmload.cpp @@ -116,12 +116,12 @@ static void resolveImportSafe(lua_State* L, Table* env, TValue* k, uint32_t id) // note: we call getimport with nil propagation which means that accesses to table chains like A.B.C will resolve in nil // this is technically not necessary but it reduces the number of exceptions when loading scripts that rely on getfenv/setfenv for global // injection - luaV_getimport(L, hvalue(gt(L)), self->k, self->id, /* propagatenil= */ true); + luaV_getimport(L, L->gt, self->k, self->id, /* propagatenil= */ true); } }; ResolveImport ri = {k, id}; - if (hvalue(gt(L))->safeenv) + if (L->gt->safeenv) { // luaD_pcall will make sure that if any C/Lua calls during import resolution fail, the thread state is restored back int oldTop = lua_gettop(L); @@ -171,7 +171,7 @@ int luau_load(lua_State* L, const char* chunkname, const char* data, size_t size L->global->GCthreshold = SIZE_MAX; // env is 0 for current environment and a stack index otherwise - Table* envt = (env == 0) ? hvalue(gt(L)) : hvalue(luaA_toobject(L, env)); + Table* envt = (env == 0) ? L->gt : hvalue(luaA_toobject(L, env)); TString* source = luaS_new(L, chunkname); diff --git a/VM/src/lvmutils.cpp b/VM/src/lvmutils.cpp index 31dd59c86..8a18a4d46 100644 --- a/VM/src/lvmutils.cpp +++ b/VM/src/lvmutils.cpp @@ -55,7 +55,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 { ptrdiff_t result = savestack(L, res); // using stack room beyond top is technically safe here, but for very complicated reasons: - // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated + // * The stack guarantees EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers // * we cannot use savestack/restorestack because the arguments are sometimes on the C++ stack @@ -76,7 +76,7 @@ static void callTMres(lua_State* L, StkId res, const TValue* f, const TValue* p1 static void callTM(lua_State* L, const TValue* f, const TValue* p1, const TValue* p2, const TValue* p3) { // using stack room beyond top is technically safe here, but for very complicated reasons: - // * The stack guarantees 1 + EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated + // * The stack guarantees EXTRA_STACK room beyond stack_last (see luaD_reallocstack) will be allocated // * we cannot move luaD_checkstack above because the arguments are *sometimes* pointers to the lua // stack and checkstack may invalidate those pointers // * we cannot use savestack/restorestack because the arguments are sometimes on the C++ stack diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 55e0888b9..04638d23d 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -32,7 +32,7 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) Luau::LintOptions lintOptions; lintOptions.warningMask = ~0ull; - Luau::lint(parseResult.root, names, typeck.globalScope, nullptr, lintOptions); + Luau::lint(parseResult.root, names, typeck.globalScope, nullptr, {}, lintOptions); } return 0; diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 27e534927..f407248a5 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -227,7 +227,7 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) if (kFuzzLinter) { Luau::LintOptions lintOptions = {~0u}; - Luau::lint(parseResult.root, names, sharedEnv.globalScope, module.get(), lintOptions); + Luau::lint(parseResult.root, names, sharedEnv.globalScope, module.get(), {}, lintOptions); } } diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 59e125740..1978a0d31 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -1,7 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Autocomplete.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" #include "Luau/VisitTypeVar.h" @@ -2610,6 +2609,27 @@ a = if temp then even elseif true then temp else e@9 CHECK(ac.entryMap.count("elseif") == 0); } +TEST_CASE_FIXTURE(ACFixture, "autocomplete_if_else_regression") +{ + ScopedFastFlag FFlagLuauIfElseExprFixCompletionIssue("LuauIfElseExprFixCompletionIssue", true); + check(R"( +local abcdef = 0; +local temp = false +local even = true; +local a +a = if temp then even else@1 +a = if temp then even else @2 +a = if temp then even else abc@3 + )"); + + auto ac = autocomplete('1'); + CHECK(ac.entryMap.count("else") == 0); + ac = autocomplete('2'); + CHECK(ac.entryMap.count("else") == 0); + ac = autocomplete('3'); + CHECK(ac.entryMap.count("abcdef")); +} + TEST_CASE_FIXTURE(ACFixture, "autocomplete_explicit_type_pack") { check(R"( diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index cd7a21d80..f982c86fa 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -10,6 +10,11 @@ #include #include +namespace Luau +{ +std::string rep(const std::string& s, size_t n); +} + using namespace Luau; static std::string compileFunction(const char* source, uint32_t id) @@ -1960,15 +1965,6 @@ RETURN R8 -1 )"); } -static std::string rep(const std::string& s, size_t n) -{ - std::string r; - r.reserve(s.length() * n); - for (size_t i = 0; i < n; ++i) - r += s; - return r; -} - TEST_CASE("RecursionParse") { // The test forcibly pushes the stack limit during compilation; in NoOpt, the stack consumption is much larger so we need to reduce the limit to diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 8b58d2ce8..b09c1efb9 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -492,6 +492,8 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { + ScopedFastFlag luauTableFieldFunctionDebugname{"LuauTableFieldFunctionDebugname", true}; + runConformance("debug.lua"); } @@ -890,6 +892,12 @@ TEST_CASE("Coverage") lua_pushstring(L, function); lua_setfield(L, -2, "name"); + lua_pushinteger(L, linedefined); + lua_setfield(L, -2, "linedefined"); + + lua_pushinteger(L, depth); + lua_setfield(L, -2, "depth"); + for (size_t i = 0; i < size; ++i) if (hits[i] != -1) { diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index c74bfa272..dbdd06a44 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -2,6 +2,7 @@ #include "Fixture.h" #include "Luau/AstQuery.h" +#include "Luau/Parser.h" #include "Luau/TypeVar.h" #include "Luau/TypeAttach.h" #include "Luau/Transpiler.h" @@ -112,7 +113,7 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars sourceModule->name = fromString(mainModuleName); sourceModule->root = result.root; sourceModule->mode = parseMode(result.hotcomments); - sourceModule->ignoreLints = LintWarning::parseMask(result.hotcomments); + sourceModule->hotcomments = std::move(result.hotcomments); if (!result.errors.empty()) { @@ -157,6 +158,7 @@ CheckResult Fixture::check(const std::string& source) LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) { ParseOptions parseOptions; + parseOptions.captureComments = true; configResolver.defaultConfig.mode = Mode::Nonstrict; parse(source, parseOptions); diff --git a/tests/Fixture.h b/tests/Fixture.h index ab852ef6d..4e45a952a 100644 --- a/tests/Fixture.h +++ b/tests/Fixture.h @@ -8,7 +8,6 @@ #include "Luau/Linter.h" #include "Luau/Location.h" #include "Luau/ModuleResolver.h" -#include "Luau/Parser.h" #include "Luau/ToString.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Frontend.test.cpp b/tests/Frontend.test.cpp index ea1a08fe7..8a59acd18 100644 --- a/tests/Frontend.test.cpp +++ b/tests/Frontend.test.cpp @@ -2,7 +2,6 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" #include "Luau/Frontend.h" -#include "Luau/Parser.h" #include "Luau/RequireTracer.h" #include "Fixture.h" @@ -897,8 +896,6 @@ TEST_CASE_FIXTURE(FrontendFixture, "clearStats") TEST_CASE_FIXTURE(FrontendFixture, "typecheck_twice_for_ast_types") { - ScopedFastFlag sffs("LuauTypeCheckTwice", true); - fileResolver.source["Module/A"] = R"( local a = 1 )"; diff --git a/tests/JsonEncoder.test.cpp b/tests/JsonEncoder.test.cpp index 4a717275f..cb5080720 100644 --- a/tests/JsonEncoder.test.cpp +++ b/tests/JsonEncoder.test.cpp @@ -46,7 +46,7 @@ TEST_CASE("encode_AstStatBlock") AstStatBlock block{Location(), bodyArray}; CHECK_EQ( - (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":["a_local"],"values":[]}]})"), + (R"({"type":"AstStatBlock","location":"0,0 - 0,0","body":[{"type":"AstStatLocal","location":"0,0 - 0,0","vars":[{"type":null,"name":"a_local","location":"0,0 - 0,0"}],"values":[]}]})"), toJson(&block)); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 577415fca..d4b973607 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1395,12 +1395,10 @@ TEST_CASE_FIXTURE(Fixture, "DeprecatedApi") TypeId colorType = typeChecker.globalTypes.addType(TableTypeVar{{}, std::nullopt, typeChecker.globalScope->level, Luau::TableState::Sealed}); - getMutable(colorType)->props = { - {"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"} } - }; + getMutable(colorType)->props = {{"toHSV", {typeChecker.anyType, /* deprecated= */ true, "Color3:ToHSV"}}}; addGlobalBinding(typeChecker, "Color3", Binding{colorType, {}}); - + freeze(typeChecker.globalTypes); LintResult result = lintTyped(R"( @@ -1554,8 +1552,46 @@ _ = (math.random() < 0.5 and false) or 42 -- currently ignored )"); REQUIRE_EQ(result.warnings.size(), 2); - CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; consider using if-then-else expression instead"); - CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; consider using if-then-else expression instead"); + CHECK_EQ(result.warnings[0].text, "The and-or expression always evaluates to the second alternative because the first alternative is false; " + "consider using if-then-else expression instead"); + CHECK_EQ(result.warnings[1].text, "The and-or expression always evaluates to the second alternative because the first alternative is nil; " + "consider using if-then-else expression instead"); +} + +TEST_CASE_FIXTURE(Fixture, "WrongComment") +{ + ScopedFastFlag sff("LuauParseAllHotComments", true); + + LintResult result = lint(R"( +--!strict +--!struct +--!nolintGlobal +--!nolint Global +--!nolint KnownGlobal +--!nolint UnknownGlobal +--! no more lint +--!strict here +do end +--!nolint +)"); + + REQUIRE_EQ(result.warnings.size(), 6); + CHECK_EQ(result.warnings[0].text, "Unknown comment directive 'struct'; did you mean 'strict'?"); + CHECK_EQ(result.warnings[1].text, "Unknown comment directive 'nolintGlobal'"); + CHECK_EQ(result.warnings[2].text, "nolint directive refers to unknown lint rule 'Global'"); + CHECK_EQ(result.warnings[3].text, "nolint directive refers to unknown lint rule 'KnownGlobal'; did you mean 'UnknownGlobal'?"); + CHECK_EQ(result.warnings[4].text, "Comment directive with the type checking mode has extra symbols at the end of the line"); + CHECK_EQ(result.warnings[5].text, "Comment directive is ignored because it is placed after the first non-comment token"); +} + +TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") +{ + LintResult result = lint(R"( +--!nolint +--!struct +)"); + + REQUIRE_EQ(result.warnings.size(), 0); // --!nolint disables WrongComment lint :) } TEST_SUITE_END(); diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 931a8403a..5bad99014 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index c1a8887b6..0d4c088dd 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/Parser.h" -#include "Luau/TypeInfer.h" #include "Fixture.h" #include "ScopedFlags.h" @@ -300,8 +299,9 @@ TEST_CASE_FIXTURE(Fixture, "functions_can_have_return_annotations") AstStatFunction* statFunction = block->body.data[0]->as(); REQUIRE(statFunction != nullptr); - CHECK_EQ(statFunction->func->returnAnnotation.types.size, 1); - CHECK(statFunction->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunction->func->returnAnnotation.has_value()); + CHECK_EQ(statFunction->func->returnAnnotation->types.size, 1); + CHECK(statFunction->func->returnAnnotation->tailType == nullptr); } TEST_CASE_FIXTURE(Fixture, "functions_can_have_a_function_type_annotation") @@ -316,9 +316,9 @@ TEST_CASE_FIXTURE(Fixture, "functions_can_have_a_function_type_annotation") AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 1); AstTypeFunction* funTy = retTypes.data[0]->as(); @@ -337,9 +337,9 @@ TEST_CASE_FIXTURE(Fixture, "function_return_type_should_disambiguate_from_functi AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 2); AstTypeReference* ty0 = retTypes.data[0]->as(); @@ -363,9 +363,9 @@ TEST_CASE_FIXTURE(Fixture, "function_return_type_should_parse_as_function_type_a AstStatFunction* statFunc = block->body.data[0]->as(); REQUIRE(statFunc != nullptr); - AstArray& retTypes = statFunc->func->returnAnnotation.types; - REQUIRE(statFunc->func->hasReturnAnnotation); - CHECK(statFunc->func->returnAnnotation.tailType == nullptr); + REQUIRE(statFunc->func->returnAnnotation.has_value()); + CHECK(statFunc->func->returnAnnotation->tailType == nullptr); + AstArray& retTypes = statFunc->func->returnAnnotation->types; REQUIRE(retTypes.size == 1); AstTypeFunction* funTy = retTypes.data[0]->as(); @@ -707,12 +707,25 @@ TEST_CASE_FIXTURE(Fixture, "mode_is_unset_if_no_hot_comment") TEST_CASE_FIXTURE(Fixture, "sense_hot_comment_on_first_line") { - ParseResult result = parseEx(" --!strict "); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx(" --!strict ", options); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); CHECK_EQ(int(*mode), int(Mode::Strict)); } +TEST_CASE_FIXTURE(Fixture, "non_header_hot_comments") +{ + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("do end --!strict", options); + std::optional mode = parseMode(result.hotcomments); + REQUIRE(!mode); +} + TEST_CASE_FIXTURE(Fixture, "stop_if_line_ends_with_hyphen") { CHECK_THROWS_AS(parse(" -"), std::exception); @@ -720,7 +733,10 @@ TEST_CASE_FIXTURE(Fixture, "stop_if_line_ends_with_hyphen") TEST_CASE_FIXTURE(Fixture, "nonstrict_mode") { - ParseResult result = parseEx("--!nonstrict"); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("--!nonstrict", options); CHECK(result.errors.empty()); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); @@ -729,7 +745,10 @@ TEST_CASE_FIXTURE(Fixture, "nonstrict_mode") TEST_CASE_FIXTURE(Fixture, "nocheck_mode") { - ParseResult result = parseEx("--!nocheck"); + ParseOptions options; + options.captureComments = true; + + ParseResult result = parseEx("--!nocheck", options); CHECK(result.errors.empty()); std::optional mode = parseMode(result.hotcomments); REQUIRE(bool(mode)); @@ -1498,8 +1517,6 @@ return TEST_CASE_FIXTURE(Fixture, "parse_error_broken_comment") { - ScopedFastFlag luauStartingBrokenComment{"LuauStartingBrokenComment", true}; - const char* expected = "Expected identifier when parsing expression, got unfinished comment"; matchParseError("--[[unfinished work", expected); diff --git a/tests/Repl.test.cpp b/tests/Repl.test.cpp index 1f9c97397..87a1e1e2a 100644 --- a/tests/Repl.test.cpp +++ b/tests/Repl.test.cpp @@ -73,7 +73,7 @@ class ReplFixture private: std::unique_ptr luaState; - // This is a simplicitic and incomplete pretty printer. + // This is a simplistic and incomplete pretty printer. // It is included here to test that the pretty printer hook is being called. // More elaborate tests to ensure correct output can be added if we introduce // a more feature rich pretty printer. @@ -158,12 +158,25 @@ TEST_CASE_FIXTURE(ReplFixture, "CompleteGlobalVariables") myvariable1 = 5 myvariable2 = 5 )"); - CompletionSet completions = getCompletionSet("myvar"); + { + // Try to complete globals that are added by the user's script + CompletionSet completions = getCompletionSet("myvar"); + + std::string prefix = ""; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "myvariable1")); + CHECK(checkCompletion(completions, prefix, "myvariable2")); + } + { + // Try completing some builtin functions + CompletionSet completions = getCompletionSet("math.m"); - std::string prefix = ""; - CHECK(completions.size() == 2); - CHECK(checkCompletion(completions, prefix, "myvariable1")); - CHECK(checkCompletion(completions, prefix, "myvariable2")); + std::string prefix = "math."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "max(")); + CHECK(checkCompletion(completions, prefix, "min(")); + CHECK(checkCompletion(completions, prefix, "modf(")); + } } TEST_CASE_FIXTURE(ReplFixture, "CompleteTableKeys") @@ -206,4 +219,188 @@ TEST_CASE_FIXTURE(ReplFixture, "StringMethods") } } +TEST_CASE_FIXTURE(ReplFixture, "TableWithMetatableIndexTable") +{ + runCode(L, R"( + -- Create 't' which is a table with a metatable with an __index table + mt = {} + mt.__index = mt + + t = {} + setmetatable(t, mt) + + mt.mtkey1 = {x="x value", y="y value", 1, 2} + mt.mtkey2 = 2 + + t.tkey1 = {data1 = 2, data2 = "str", 3, 4} + t.tkey2 = 4 +)"); + { + CompletionSet completions = getCompletionSet("t.t"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "tkey1")); + CHECK(checkCompletion(completions, prefix, "tkey2")); + } + { + CompletionSet completions = getCompletionSet("t.tkey1.data2:re"); + + std::string prefix = "t.tkey1.data2:"; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "rep(")); + CHECK(checkCompletion(completions, prefix, "reverse(")); + } + { + CompletionSet completions = getCompletionSet("t.mtk"); + + std::string prefix = "t."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "mtkey1")); + CHECK(checkCompletion(completions, prefix, "mtkey2")); + } + { + CompletionSet completions = getCompletionSet("t.mtkey1."); + + std::string prefix = "t.mtkey1."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "x")); + CHECK(checkCompletion(completions, prefix, "y")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithMetatableIndexFunction") +{ + runCode(L, R"( + -- Create 't' which is a table with a metatable with an __index function + mt = {} + mt.__index = function(table, key) + print("mt.__index called") + if key == "foo" then + return "FOO" + elseif key == "bar" then + return "BAR" + else + return nil + end + end + + t = {} + setmetatable(t, mt) + t.tkey = 0 +)"); + { + CompletionSet completions = getCompletionSet("t.t"); + + std::string prefix = "t."; + CHECK(completions.size() == 1); + CHECK(checkCompletion(completions, prefix, "tkey")); + } + { + // t.foo is a valid key, but should not be completed because it requires calling an __index function + CompletionSet completions = getCompletionSet("t.foo"); + + CHECK(completions.size() == 0); + } + { + // t.foo is a valid key, but should not be found because it requires calling an __index function + CompletionSet completions = getCompletionSet("t.foo:"); + + CHECK(completions.size() == 0); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithMultipleMetatableIndexTables") +{ + runCode(L, R"( + -- Create a table with a chain of metatables + mt2 = {} + mt2.__index = mt2 + + mt = {} + mt.__index = mt + setmetatable(mt, mt2) + + t = {} + setmetatable(t, mt) + + mt2.mt2key = {x=1, y=2} + mt.mtkey = 2 + t.tkey = 3 +)"); + { + CompletionSet completions = getCompletionSet("t."); + + std::string prefix = "t."; + CHECK(completions.size() == 4); + CHECK(checkCompletion(completions, prefix, "__index")); + CHECK(checkCompletion(completions, prefix, "tkey")); + CHECK(checkCompletion(completions, prefix, "mtkey")); + CHECK(checkCompletion(completions, prefix, "mt2key")); + } + { + CompletionSet completions = getCompletionSet("t.__index."); + + std::string prefix = "t.__index."; + CHECK(completions.size() == 3); + CHECK(checkCompletion(completions, prefix, "__index")); + CHECK(checkCompletion(completions, prefix, "mtkey")); + CHECK(checkCompletion(completions, prefix, "mt2key")); + } + { + CompletionSet completions = getCompletionSet("t.mt2key."); + + std::string prefix = "t.mt2key."; + CHECK(completions.size() == 2); + CHECK(checkCompletion(completions, prefix, "x")); + CHECK(checkCompletion(completions, prefix, "y")); + } +} + +TEST_CASE_FIXTURE(ReplFixture, "TableWithDeepMetatableIndexTables") +{ + runCode(L, R"( +-- Creates a table with a chain of metatables of length `count` +function makeChainedTable(count) + local result = {} + result.__index = result + result[string.format("entry%d", count)] = { count = count } + if count == 0 then + return result + else + return setmetatable(result, makeChainedTable(count - 1)) + end +end + +t30 = makeChainedTable(30) +t60 = makeChainedTable(60) +)"); + { + // Check if entry0 exists + CompletionSet completions = getCompletionSet("t30.entry0"); + + std::string prefix = "t30."; + CHECK(checkCompletion(completions, prefix, "entry0")); + } + { + // Check if entry0.count exists + CompletionSet completions = getCompletionSet("t30.entry0.co"); + + std::string prefix = "t30.entry0."; + CHECK(checkCompletion(completions, prefix, "count")); + } + { + // Check if entry0 exists. With the max traversal limit of 50 in the repl, this should fail. + CompletionSet completions = getCompletionSet("t60.entry0"); + + CHECK(completions.size() == 0); + } + { + // Check if entry0.count exists. With the max traversal limit of 50 in the repl, this should fail. + CompletionSet completions = getCompletionSet("t60.entry0.co"); + + CHECK(completions.size() == 0); + } +} + TEST_SUITE_END(); diff --git a/tests/RequireTracer.test.cpp b/tests/RequireTracer.test.cpp index b9fd04d69..ba03f3638 100644 --- a/tests/RequireTracer.test.cpp +++ b/tests/RequireTracer.test.cpp @@ -1,6 +1,6 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/RequireTracer.h" +#include "Luau/Parser.h" #include "Fixture.h" diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index a87292683..31d7ef10b 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -598,15 +598,13 @@ TEST_CASE_FIXTURE(Fixture, "generic_typevars_are_not_considered_to_escape_their_ /* * The two-pass alias definition system starts by ascribing a free TypeVar to each alias. It then * circles back to fill in the actual type later on. - * + * * If this free type is unified with something degenerate like `any`, we need to take extra care * to ensure that the alias actually binds to the type that the user expected. */ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_unification_with_any") { - ScopedFastFlag sff[] = { - {"LuauTwoPassAliasDefinitionFix", true} - }; + ScopedFastFlag sff[] = {{"LuauTwoPassAliasDefinitionFix", true}}; CheckResult result = check(R"( local function x() diff --git a/tests/TypeInfer.annotations.test.cpp b/tests/TypeInfer.annotations.test.cpp index 572b882d8..2ad11d01c 100644 --- a/tests/TypeInfer.annotations.test.cpp +++ b/tests/TypeInfer.annotations.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index df06884d8..f3dfb214d 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" diff --git a/tests/TypeInfer.classes.test.cpp b/tests/TypeInfer.classes.test.cpp index 0283ae192..98fa66eb1 100644 --- a/tests/TypeInfer.classes.test.cpp +++ b/tests/TypeInfer.classes.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index 2652486be..c6d55793d 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 8a2c6f27e..f8fccf6b0 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 93c0baf6d..d677e28d8 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index 2bcd840c8..eee0e0f17 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Fixture.h" diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index df365fda4..9021700dc 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -426,8 +426,6 @@ TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") {"LuauSingletonTypes", true}, {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, - {"LuauIfElseExpectedType2", true}, - {"LuauIfElseBranchTypeUnion", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index f19cb618b..6bcd4b99a 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -2168,8 +2167,6 @@ b() TEST_CASE_FIXTURE(Fixture, "length_operator_union") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} | {string} local y = #x @@ -2180,8 +2177,6 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_intersection") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} & {z:string} -- mixed tables are evil local y = #x @@ -2192,8 +2187,6 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_non_table_union") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} | any | string local y = #x @@ -2204,8 +2197,6 @@ local y = #x TEST_CASE_FIXTURE(Fixture, "length_operator_union_errors") { - ScopedFastFlag luauLengthOnCompositeType{"LuauLengthOnCompositeType", true}; - CheckResult result = check(R"( local x: {number} | number | string local y = #x @@ -2214,4 +2205,38 @@ local y = #x LUAU_REQUIRE_ERROR_COUNT(1, result); } +TEST_CASE_FIXTURE(Fixture, "dont_hang_when_trying_to_look_up_in_cyclic_metatable_index") +{ + ScopedFastFlag sff{"LuauTerminateCyclicMetatableIndexLookup", true}; + + // t :: t1 where t1 = {metatable {__index: t1, __tostring: (t1) -> string}} + CheckResult result = check(R"( + local mt = {} + local t = setmetatable({}, mt) + mt.__index = t + + function mt:__tostring() + return t.p + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 't' does not have key 'p'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") +{ + CheckResult result = check(R"( + local data = { x = 5 } + local t1 = setmetatable({}, { __index = data }) + local t2 = setmetatable({}, t1) -- note: must be t1, not a new table + + local x1 = t1.x -- ok + local x2 = t2.x -- nope + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 't2' does not have key 'x'", toString(result.errors[0])); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 531a382f5..323585712 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -2,7 +2,6 @@ #include "Luau/AstQuery.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -4654,8 +4653,6 @@ a = setmetatable(a, { __call = function(x) end }) TEST_CASE_FIXTURE(Fixture, "infer_through_group_expr") { - ScopedFastFlag luauGroupExpectedType{"LuauGroupExpectedType", true}; - CheckResult result = check(R"( local function f(a: (number, number) -> number) return a(1, 3) end f(((function(a, b) return a + b end))) @@ -4735,21 +4732,14 @@ local a = if false then "a" elseif false then "b" else "c" TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_type_union") { - ScopedFastFlag sff3{"LuauIfElseBranchTypeUnion", true}; - - { - CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); + CheckResult result = check(R"(local a: number? = if true then 42 else nil)"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a"), {true}), "number?"); - } + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a"), {true}), "number?"); } TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_1") { - ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - CheckResult result = check(R"( type X = {number | string} local a: X = if true then {"1", 2, 3} else {4, 5, 6} @@ -4761,9 +4751,6 @@ local a: X = if true then {"1", 2, 3} else {4, 5, 6} TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_2") { - ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - ScopedFastFlag luauIfElseBranchTypeUnion{"LuauIfElseBranchTypeUnion", true}; - CheckResult result = check(R"( local a: number? = if true then 1 else nil )"); @@ -4773,8 +4760,6 @@ local a: number? = if true then 1 else nil TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions_expected_type_3") { - ScopedFastFlag luauIfElseExpectedType2{"LuauIfElseExpectedType2", true}; - CheckResult result = check(R"( local function times(n: any, f: () -> T) local result: {T} = {} @@ -5058,8 +5043,6 @@ end TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") { - ScopedFastFlag luauMetatableAreEqualRecursion{"LuauMetatableAreEqualRecursion", true}; - CheckResult result = check(R"( local function getIt() local y @@ -5076,8 +5059,6 @@ local c = a or b TEST_CASE_FIXTURE(Fixture, "bound_typepack_promote") { - ScopedFastFlag luauCommittingTxnLogFreeTpPromote{"LuauCommittingTxnLogFreeTpPromote", true}; - // No assertions should trigger check(R"( local function p() @@ -5251,7 +5232,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") {"LuauDiscriminableUnions2", true}, {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, - {"LuauLengthOnCompositeType", true}, }; CheckResult result = check(R"( @@ -5272,7 +5252,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") {"LuauDiscriminableUnions2", true}, {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, - {"LuauLengthOnCompositeType", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 8c7fb79ab..4669ea8eb 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 079870f57..cbe2e48f2 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -1,6 +1,5 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/BuiltinDefinitions.h" -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" @@ -930,8 +929,6 @@ type R = { m: F } TEST_CASE_FIXTURE(Fixture, "pack_tail_unification_check") { - ScopedFastFlag luauUnifyPackTails{"LuauUnifyPackTails", true}; - CheckResult result = check(R"( local a: () -> (number, ...string) local b: () -> (number, ...boolean) diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 759794e63..3b53ddfe3 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypePack.test.cpp b/tests/TypePack.test.cpp index 8b056544a..c4931578d 100644 --- a/tests/TypePack.test.cpp +++ b/tests/TypePack.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index 329e7b1f6..e43161fa3 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -1,5 +1,4 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details -#include "Luau/Parser.h" #include "Luau/Scope.h" #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" diff --git a/tests/conformance/coroutine.lua b/tests/conformance/coroutine.lua index f2ecc96bb..b4f81bbae 100644 --- a/tests/conformance/coroutine.lua +++ b/tests/conformance/coroutine.lua @@ -371,6 +371,15 @@ do st, msg = coroutine.close(co) assert(st and msg == nil) assert(f() == 42) + + -- closing a coroutine with a large stack + co = coroutine.create(function() + local function f(depth) return if depth > 0 then f(depth - 1) + depth else 0 end + coroutine.yield(f(100)) + end) + assert(coroutine.resume(co)) + st, msg = coroutine.close(co) + assert(st and msg == nil) end return 'OK' diff --git a/tests/conformance/coverage.lua b/tests/conformance/coverage.lua index f899603f9..14d843a43 100644 --- a/tests/conformance/coverage.lua +++ b/tests/conformance/coverage.lua @@ -49,16 +49,24 @@ foo() c = getcoverage(foo) assert(#c == 1) assert(c[1].name == "foo") +assert(c[1].linedefined == 4) +assert(c[1].depth == 0) assert(validate(c[1], {5, 6, 7}, {})) bar() c = getcoverage(bar) assert(#c == 3) assert(c[1].name == "bar") +assert(c[1].linedefined == 10) +assert(c[1].depth == 0) assert(validate(c[1], {11, 15, 19}, {})) assert(c[2].name == "one") +assert(c[2].linedefined == 11) +assert(c[2].depth == 1) assert(validate(c[2], {12}, {})) assert(c[3].name == nil) +assert(c[3].linedefined == 15) +assert(c[3].depth == 1) assert(validate(c[3], {}, {16})) return 'OK' diff --git a/tests/conformance/debug.lua b/tests/conformance/debug.lua index 0e4100005..0c8cc2d87 100644 --- a/tests/conformance/debug.lua +++ b/tests/conformance/debug.lua @@ -76,6 +76,9 @@ assert(baz(co, 2, "n") == nil) assert(baz(math.sqrt, "n") == "sqrt") assert(baz(math.sqrt, "f") == math.sqrt) -- yes this is pointless +local t = { foo = function() return 1 end } +assert(baz(t.foo, "n") == "foo") + -- info multi-arg returns function quux(...) return {debug.info(...)} diff --git a/tests/main.cpp b/tests/main.cpp index cd24e100f..2af9f7023 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -10,6 +10,9 @@ #ifndef WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN #endif +#ifndef NOMINMAX +#define NOMINMAX +#endif #include // IsDebuggerPresent #endif @@ -52,7 +55,7 @@ static bool debuggerPresent() #endif } -static int assertionHandler(const char* expr, const char* file, int line, const char* function) +static int testAssertionHandler(const char* expr, const char* file, int line, const char* function) { if (debuggerPresent()) LUAU_DEBUGBREAK(); @@ -218,7 +221,7 @@ static void setFastFlags(const std::vector& flags) int main(int argc, char** argv) { - Luau::assertHandler() = assertionHandler; + Luau::assertHandler() = testAssertionHandler; doctest::registerReporter("boost", 0, true); diff --git a/tools/natvis/Analysis.natvis b/tools/natvis/Analysis.natvis new file mode 100644 index 000000000..5de0140e2 --- /dev/null +++ b/tools/natvis/Analysis.natvis @@ -0,0 +1,78 @@ + + + + + AnyTypeVar + + + + {{ index=0, value={*($T1*)storage} }} + {{ index=1, value={*($T2*)storage} }} + {{ index=2, value={*($T3*)storage} }} + {{ index=3, value={*($T4*)storage} }} + {{ index=4, value={*($T5*)storage} }} + {{ index=5, value={*($T6*)storage} }} + {{ index=6, value={*($T7*)storage} }} + {{ index=7, value={*($T8*)storage} }} + {{ index=8, value={*($T9*)storage} }} + {{ index=9, value={*($T10*)storage} }} + {{ index=10, value={*($T11*)storage} }} + {{ index=11, value={*($T12*)storage} }} + {{ index=12, value={*($T13*)storage} }} + {{ index=13, value={*($T14*)storage} }} + {{ index=14, value={*($T15*)storage} }} + {{ index=15, value={*($T16*)storage} }} + {{ index=16, value={*($T17*)storage} }} + {{ index=17, value={*($T18*)storage} }} + {{ index=18, value={*($T19*)storage} }} + {{ index=19, value={*($T20*)storage} }} + {{ index=20, value={*($T21*)storage} }} + {{ index=21, value={*($T22*)storage} }} + {{ index=22, value={*($T23*)storage} }} + {{ index=23, value={*($T24*)storage} }} + {{ index=24, value={*($T25*)storage} }} + {{ index=25, value={*($T26*)storage} }} + {{ index=26, value={*($T27*)storage} }} + {{ index=27, value={*($T28*)storage} }} + {{ index=28, value={*($T29*)storage} }} + {{ index=29, value={*($T30*)storage} }} + {{ index=30, value={*($T31*)storage} }} + {{ index=31, value={*($T32*)storage} }} + + typeId + *($T1*)storage + *($T2*)storage + *($T3*)storage + *($T4*)storage + *($T5*)storage + *($T6*)storage + *($T7*)storage + *($T8*)storage + *($T9*)storage + *($T10*)storage + *($T11*)storage + *($T12*)storage + *($T13*)storage + *($T14*)storage + *($T15*)storage + *($T16*)storage + *($T17*)storage + *($T18*)storage + *($T19*)storage + *($T20*)storage + *($T21*)storage + *($T22*)storage + *($T23*)storage + *($T24*)storage + *($T25*)storage + *($T26*)storage + *($T27*)storage + *($T28*)storage + *($T29*)storage + *($T30*)storage + *($T31*)storage + *($T32*)storage + + + + diff --git a/tools/natvis/Ast.natvis b/tools/natvis/Ast.natvis new file mode 100644 index 000000000..322eb8f67 --- /dev/null +++ b/tools/natvis/Ast.natvis @@ -0,0 +1,25 @@ + + + + + AstArray size={size} + + size + + size + data + + + + + + + size_ + + size_ + storage._Mypair._Myval2._Myfirst + offset + + + + + diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis new file mode 100644 index 000000000..9924e194f --- /dev/null +++ b/tools/natvis/VM.natvis @@ -0,0 +1,269 @@ + + + + + nil + {(bool)value.b} + lightuserdata {value.p} + number = {value.n} + vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} + {value.gc->ts} + {value.gc->h} + function {value.gc->cl,view(short)} + userdata {value.gc->u} + thread {value.gc->th} + proto {value.gc->p} + upvalue {value.gc->uv} + deadkey + empty + + value.p + value.gc->ts + value.gc->h + value.gc->cl + value.gc->cl + value.gc->u + value.gc->th + value.gc->p + value.gc->uv + + fixed ({(int)value.gc->gch.marked}) + black ({(int)value.gc->gch.marked}) + white ({(int)value.gc->gch.marked}) + white ({(int)value.gc->gch.marked}) + gray ({(int)value.gc->gch.marked}) + + + + + + nil + {(bool)value.b} + lightuserdata {value.p} + number = {value.n} + vector = {value.v[0]}, {value.v[1]}, {*(float*)&extra} + {value.gc->ts} + {value.gc->h} + function {value.gc->cl,view(short)} + userdata {value.gc->u} + thread {value.gc->th} + proto {value.gc->p} + upvalue {value.gc->uv} + deadkey + empty + + (void**)value.p + value.gc->ts + value.gc->h + value.gc->cl + value.gc->cl + value.gc->u + value.gc->th + value.gc->p + value.gc->uv + + next + + + + + {key,na} = {val} + --- + + + + table + + metatable + + + [size] {1<<lsizenode} + + + 1<<lsizenode + node[$i] + + + + + [size] {sizearray} + + + sizearray + array[$i] + + + + + + + + + + + + + 1 + + + + + metatable->node[i].val + + + + i = i + 1 + + + "unknown",sb + + + tag + len + metatable + data + + + + + {c.f,na} + {l.p,na} + {c} + {l} + invalid + + + + {data,s} + + + + + {ci->func->value.gc->cl.c.f,na} + + + {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} {ci->func->value.gc->cl.l.p->debugname->data,sb} + + + {ci->func->value.gc->cl.l.p->source->data,sb}:{ci->func->value.gc->cl.l.p->linedefined,d} + + thread + + + {ci-base_ci} frames + + + ci-base_ci + + + base_ci[ci-base_ci - $i].func->value.gc->cl,view(short) + + + + + + {top-base} values + + + top-base + base + + + + + {top-stack} values + + + top-stack + stack + + + + + + + openupval + u.l.next + this + + + + l_gt + env + userdata + + + + + {source->data,sb}:{linedefined} function {debugname->data,sb} [{(int)numparams} arg, {(int)nups} upval] + {source->data,sb}:{linedefined} [{(int)numparams} arg, {(int)nups} upval] + + debugname + + constants + + + sizek + k[$i] + + + + + locals + + + sizelocvars + locvars[$i] + + + + + bytecode + + + sizecode + code[$i] + + + + + functions + + + sizep + p[$i] + + + + + upvals + + + sizeupvalues + upvalues[$i] + + + + + source + + + + + + + {(lua_Type)tt} + + + fixed ({(int)marked}) + black ({(int)marked}) + white ({(int)marked}) + white ({(int)marked}) + gray ({(int)marked}) + unknown + + memcat + + + + From a8eabedd570e9b3aba7e02ff2b0f4d8bdbf9efbb Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 24 Feb 2022 15:15:41 -0800 Subject: [PATCH 27/32] Sync to upstream/release/516 --- Analysis/include/Luau/Module.h | 9 +- Analysis/include/Luau/TypeInfer.h | 2 + Analysis/include/Luau/TypeUtils.h | 4 +- Analysis/include/Luau/Unifier.h | 29 ++- Analysis/src/Autocomplete.cpp | 4 +- Analysis/src/Linter.cpp | 1 + Analysis/src/Module.cpp | 27 +-- Analysis/src/Transpiler.cpp | 8 + Analysis/src/TypeInfer.cpp | 67 +++--- Analysis/src/TypeUtils.cpp | 8 +- Analysis/src/TypeVar.cpp | 110 +++------- Analysis/src/Unifier.cpp | 126 +++++++----- Ast/src/Parser.cpp | 8 +- VM/include/lua.h | 12 +- VM/src/lapi.cpp | 20 +- VM/src/ldo.cpp | 3 +- VM/src/lfunc.cpp | 62 ++---- VM/src/lgc.cpp | 218 +++----------------- VM/src/lgc.h | 10 +- VM/src/lgcdebug.cpp | 72 +------ VM/src/linit.cpp | 2 +- VM/src/lmem.cpp | 127 ++---------- VM/src/lmem.h | 4 - VM/src/lobject.h | 5 +- VM/src/lstate.cpp | 32 +-- VM/src/lstate.h | 5 - VM/src/lstring.cpp | 70 ++----- VM/src/ltable.cpp | 4 +- VM/src/ludata.cpp | 2 +- fuzz/linter.cpp | 8 +- fuzz/luau.proto | 55 +++-- fuzz/proto.cpp | 227 ++++++++++++++------- fuzz/protoprint.cpp | 129 ++++++++++-- fuzz/prototest.cpp | 12 +- fuzz/typeck.cpp | 6 +- tests/Autocomplete.test.cpp | 1 - tests/Conformance.test.cpp | 92 ++++++++- tests/Linter.test.cpp | 13 ++ tests/Parser.test.cpp | 1 - tests/Transpiler.test.cpp | 15 ++ tests/TypeInfer.intersectionTypes.test.cpp | 6 +- tests/TypeInfer.refinements.test.cpp | 16 ++ tests/TypeInfer.singletons.test.cpp | 124 +++++++++++ tests/TypeInfer.tables.test.cpp | 18 ++ tests/TypeInfer.test.cpp | 6 - tests/TypeInfer.tryUnify.test.cpp | 4 +- tests/TypeVar.test.cpp | 12 -- tools/natvis/VM.natvis | 5 +- 48 files changed, 914 insertions(+), 887 deletions(-) diff --git a/Analysis/include/Luau/Module.h b/Analysis/include/Luau/Module.h index 612007711..6c689b7ca 100644 --- a/Analysis/include/Luau/Module.h +++ b/Analysis/include/Luau/Module.h @@ -13,8 +13,6 @@ #include #include -LUAU_FASTFLAG(LuauPrepopulateUnionOptionsBeforeAllocation) - namespace Luau { @@ -60,11 +58,8 @@ struct TypeArena template TypeId addType(T tv) { - if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) - { - if constexpr (std::is_same_v) - LUAU_ASSERT(tv.options.size() >= 2); - } + if constexpr (std::is_same_v) + LUAU_ASSERT(tv.options.size() >= 2); return addTV(TypeVar(std::move(tv))); } diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 3c5ded3cc..2440c810b 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -31,6 +31,7 @@ bool doesCallError(const AstExprCall* call); bool hasBreak(AstStat* node); const AstStat* getFallthrough(const AstStat* node); +struct UnifierOptions; struct Unifier; // A substitution which replaces generic types in a given set by free types. @@ -245,6 +246,7 @@ struct TypeChecker * Treat any failures as type errors in the final typecheck report. */ bool unify(TypeId subTy, TypeId superTy, const Location& location); + bool unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options); bool unify(TypePackId subTy, TypePackId superTy, const Location& location, CountMismatch::Context ctx = CountMismatch::Context::Arg); /** Attempt to unify the types. diff --git a/Analysis/include/Luau/TypeUtils.h b/Analysis/include/Luau/TypeUtils.h index ffddfe4b3..42c1bc0be 100644 --- a/Analysis/include/Luau/TypeUtils.h +++ b/Analysis/include/Luau/TypeUtils.h @@ -13,7 +13,7 @@ namespace Luau using ScopePtr = std::shared_ptr; -std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location); -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location); +std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location); +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location); } // namespace Luau diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 9db4e22b0..fe822b012 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -19,11 +19,31 @@ enum Variance Invariant }; +// A substitution which replaces singleton types by their wider types +struct Widen : Substitution +{ + Widen(TypeArena* arena) + : Substitution(TxnLog::empty(), arena) + { + } + + bool isDirty(TypeId ty) override; + bool isDirty(TypePackId ty) override; + TypeId clean(TypeId ty) override; + TypePackId clean(TypePackId ty) override; + bool ignoreChildren(TypeId ty) override; +}; + +// TODO: Use this more widely. +struct UnifierOptions +{ + bool isFunctionCall = false; +}; + struct Unifier { TypeArena* const types; Mode mode; - ScopePtr globalScope; // sigh. Needed solely to get at string's metatable. DEPRECATED_TxnLog DEPRECATED_log; TxnLog log; @@ -34,9 +54,9 @@ struct Unifier UnifierSharedState& sharedState; - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, + Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, + Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. @@ -65,7 +85,10 @@ struct Unifier void tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed); void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); + + TypeId widen(TypeId ty); TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); + void cacheResult(TypeId subTy, TypeId superTy); public: diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 5a1ae3975..29a2c6b54 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -236,10 +236,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ { ty = follow(ty); - auto canUnify = [&typeArena, &module](TypeId subTy, TypeId superTy) { + auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); - Unifier unifier(typeArena, Mode::Strict, module.getModuleScope(), Location(), Variance::Covariant, unifierState); + Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); if (FFlag::LuauAutocompleteAvoidMutation && !FFlag::LuauUseCommittingTxnLog) { diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 8d7d2d97f..7635dc0ff 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -201,6 +201,7 @@ static bool similar(AstExpr* lhs, AstExpr* rhs) return true; } + CASE(AstExprIfElse) return similar(le->condition, re->condition) && similar(le->trueExpr, re->trueExpr) && similar(le->falseExpr, re->falseExpr); else { LUAU_ASSERT(!"Unknown expression type"); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 817a33e9f..412b78bbb 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuau LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAG(LuauTypeAliasDefaults) LUAU_FASTFLAG(LuauImmutableTypes) -LUAU_FASTFLAGVARIABLE(LuauPrepopulateUnionOptionsBeforeAllocation, false) namespace Luau { @@ -379,28 +378,14 @@ void TypeCloner::operator()(const AnyTypeVar& t) void TypeCloner::operator()(const UnionTypeVar& t) { - if (FFlag::LuauPrepopulateUnionOptionsBeforeAllocation) - { - std::vector options; - options.reserve(t.options.size()); - - for (TypeId ty : t.options) - options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - - TypeId result = dest.addType(UnionTypeVar{std::move(options)}); - seenTypes[typeId] = result; - } - else - { - TypeId result = dest.addType(UnionTypeVar{}); - seenTypes[typeId] = result; + std::vector options; + options.reserve(t.options.size()); - UnionTypeVar* option = getMutable(result); - LUAU_ASSERT(option != nullptr); + for (TypeId ty : t.options) + options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - for (TypeId ty : t.options) - option->options.push_back(clone(ty, dest, seenTypes, seenTypePacks, cloneState)); - } + TypeId result = dest.addType(UnionTypeVar{std::move(options)}); + seenTypes[typeId] = result; } void TypeCloner::operator()(const IntersectionTypeVar& t) diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index 54bd0d5e7..a02d396bd 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -1153,6 +1153,14 @@ struct Printer writer.symbol(")"); } } + else if (const auto& a = typeAnnotation.as()) + { + writer.keyword(a->value ? "true" : "false"); + } + else if (const auto& a = typeAnnotation.as()) + { + writer.string(std::string_view(a->value.data, a->value.size)); + } else if (typeAnnotation.is()) { writer.symbol("%error-type%"); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index c29699b7d..faf60eb3e 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,7 +29,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) -LUAU_FASTFLAGVARIABLE(LuauNoSealedTypeMod, false) LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) @@ -38,14 +37,14 @@ LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauProperTypeLevels, false) LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) -LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. LUAU_FASTFLAGVARIABLE(LuauAnotherTypeLevelFix, false) +LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) +LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) namespace Luau { @@ -1125,7 +1124,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco ty = follow(ty); - if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) + if (tableSelf && tableSelf->state != TableState::Sealed) tableSelf->props[indexName->index.value] = {ty, /* deprecated */ false, {}, indexName->indexLocation}; const FunctionTypeVar* funTy = get(ty); @@ -1138,7 +1137,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco checkFunctionBody(funScope, ty, *function.func); - if (tableSelf && (FFlag::LuauNoSealedTypeMod ? tableSelf->state != TableState::Sealed : !selfTy->persistent)) + if (tableSelf && tableSelf->state != TableState::Sealed) tableSelf->props[indexName->index.value] = { follow(quantify(funScope, ty, indexName->indexLocation)), /* deprecated */ false, {}, indexName->indexLocation}; } @@ -1210,8 +1209,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { ScopePtr aliasScope = childScope(scope, typealias.location); aliasScope->level = scope->level.incr(); - if (FFlag::LuauProperTypeLevels) - aliasScope->level.subLevel = subLevel; + aliasScope->level.subLevel = subLevel; auto [generics, genericPacks] = createGenericTypes(aliasScope, scope->level, typealias, typealias.generics, typealias.genericPacks, /* useCache = */ true); @@ -1624,7 +1622,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprIn std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsType, Name name, const Location& location) { ErrorVec errors; - auto result = Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); + auto result = Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); reportErrors(errors); return result; } @@ -1632,7 +1630,7 @@ std::optional TypeChecker::findTablePropertyRespectingMeta(TypeId lhsTyp std::optional TypeChecker::findMetatableEntry(TypeId type, std::string entry, const Location& location) { ErrorVec errors; - auto result = Luau::findMetatableEntry(errors, globalScope, type, entry, location); + auto result = Luau::findMetatableEntry(errors, type, entry, location); reportErrors(errors); return result; } @@ -1751,13 +1749,23 @@ std::optional TypeChecker::getIndexTypeFromType( return std::nullopt; } - // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 - std::vector result = reduceUnion(parts); + if (FFlag::LuauDoNotTryToReduce) + { + if (parts.size() == 1) + return parts[0]; - if (result.size() == 1) - return result[0]; + return addType(IntersectionTypeVar{std::move(parts)}); // Not at all correct. + } + else + { + // TODO(amccord): Write some logic to correctly handle intersections. CLI-34659 + std::vector result = reduceUnion(parts); - return addType(IntersectionTypeVar{result}); + if (result.size() == 1) + return result[0]; + + return addType(IntersectionTypeVar{result}); + } } if (addErrors) @@ -2823,10 +2831,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, TypeLevel level) { auto freshTy = [&]() { - if (FFlag::LuauProperTypeLevels) - return freshType(level); - else - return freshType(scope); + return freshType(level); }; if (auto globalName = funName.as()) @@ -3790,7 +3795,14 @@ std::optional> TypeChecker::checkCallOverload(const Scope // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - unify(fn, r, expr.location); + if (FFlag::LuauWidenIfSupertypeIsFree) + { + UnifierOptions options; + options.isFunctionCall = true; + unify(r, fn, expr.location, options); + } + else + unify(fn, r, expr.location); return {{retPack}}; } @@ -4243,9 +4255,15 @@ TypeId TypeChecker::anyIfNonstrict(TypeId ty) const } bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location) +{ + UnifierOptions options; + return unify(subTy, superTy, location, options); +} + +bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location, const UnifierOptions& options) { Unifier state = mkUnifier(location); - state.tryUnify(subTy, superTy); + state.tryUnify(subTy, superTy, options.isFunctionCall); if (FFlag::LuauUseCommittingTxnLog) state.log.commit(); @@ -4654,7 +4672,7 @@ void TypeChecker::diagnoseMissingTableKey(UnknownProperty* utk, TypeErrorData& d } }; - if (auto ttv = getTableType(FFlag::LuauUnionTagMatchFix ? utk->table : follow(utk->table))) + if (auto ttv = getTableType(utk->table)) accumulate(ttv->props); else if (auto ctv = get(follow(utk->table))) { @@ -4691,8 +4709,7 @@ ScopePtr TypeChecker::childFunctionScope(const ScopePtr& parent, const Location& ScopePtr TypeChecker::childScope(const ScopePtr& parent, const Location& location) { ScopePtr scope = std::make_shared(parent); - if (FFlag::LuauProperTypeLevels) - scope->level = parent->level; + scope->level = parent->level; scope->varargPack = parent->varargPack; currentModule->scopes.push_back(std::make_pair(location, scope)); @@ -4724,7 +4741,7 @@ void TypeChecker::merge(RefinementMap& l, const RefinementMap& r) Unifier TypeChecker::mkUnifier(const Location& location) { - return Unifier{¤tModule->internalTypes, currentModule->mode, globalScope, location, Variance::Covariant, unifierState}; + return Unifier{¤tModule->internalTypes, currentModule->mode, location, Variance::Covariant, unifierState}; } TypeId TypeChecker::freshType(const ScopePtr& scope) @@ -5444,7 +5461,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st void TypeChecker::refineLValue(const LValue& lvalue, RefinementMap& refis, const ScopePtr& scope, TypeIdPredicate predicate) { - LUAU_ASSERT(FFlag::LuauDiscriminableUnions2); + LUAU_ASSERT(FFlag::LuauDiscriminableUnions2 || FFlag::LuauAssertStripsFalsyTypes); const LValue* target = &lvalue; std::optional key; // If set, we know we took the base of the lvalue path and should be walking down each option of the base's type. diff --git a/Analysis/src/TypeUtils.cpp b/Analysis/src/TypeUtils.cpp index 593b54c84..c24358900 100644 --- a/Analysis/src/TypeUtils.cpp +++ b/Analysis/src/TypeUtils.cpp @@ -10,7 +10,7 @@ LUAU_FASTFLAGVARIABLE(LuauTerminateCyclicMetatableIndexLookup, false) namespace Luau { -std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globalScope, TypeId type, std::string entry, Location location) +std::optional findMetatableEntry(ErrorVec& errors, TypeId type, std::string entry, Location location) { type = follow(type); @@ -37,7 +37,7 @@ std::optional findMetatableEntry(ErrorVec& errors, const ScopePtr& globa return std::nullopt; } -std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const ScopePtr& globalScope, TypeId ty, Name name, Location location) +std::optional findTablePropertyRespectingMeta(ErrorVec& errors, TypeId ty, Name name, Location location) { if (get(ty)) return ty; @@ -49,7 +49,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc return it->second.type; } - std::optional mtIndex = findMetatableEntry(errors, globalScope, ty, "__index", location); + std::optional mtIndex = findMetatableEntry(errors, ty, "__index", location); int count = 0; while (mtIndex) { @@ -82,7 +82,7 @@ std::optional findTablePropertyRespectingMeta(ErrorVec& errors, const Sc else errors.push_back(TypeError{location, GenericError{"__index should either be a function or table. Got " + toString(index)}}); - mtIndex = findMetatableEntry(errors, globalScope, *mtIndex, "__index", location); + mtIndex = findMetatableEntry(errors, *mtIndex, "__index", location); } return std::nullopt; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index b2358c277..a1dcfdbec 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -23,9 +23,7 @@ LUAU_FASTFLAG(DebugLuauFreezeArena) LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) -LUAU_FASTFLAGVARIABLE(LuauRefactorTypeVarQuestions, false) LUAU_FASTFLAG(LuauErrorRecoveryType) -LUAU_FASTFLAG(LuauUnionTagMatchFix) LUAU_FASTFLAG(LuauDiscriminableUnions2) namespace Luau @@ -145,20 +143,13 @@ bool isNil(TypeId ty) bool isBoolean(TypeId ty) { - if (FFlag::LuauRefactorTypeVarQuestions) - { - if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) - return true; + if (isPrim(ty, PrimitiveTypeVar::Boolean) || get(get(follow(ty)))) + return true; - if (auto utv = get(follow(ty))) - return std::all_of(begin(utv), end(utv), isBoolean); + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isBoolean); - return false; - } - else - { - return isPrim(ty, PrimitiveTypeVar::Boolean); - } + return false; } bool isNumber(TypeId ty) @@ -168,20 +159,13 @@ bool isNumber(TypeId ty) bool isString(TypeId ty) { - if (FFlag::LuauRefactorTypeVarQuestions) - { - if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) - return true; + if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) + return true; - if (auto utv = get(follow(ty))) - return std::all_of(begin(utv), end(utv), isString); + if (auto utv = get(follow(ty))) + return std::all_of(begin(utv), end(utv), isString); - return false; - } - else - { - return isPrim(ty, PrimitiveTypeVar::String); - } + return false; } bool isThread(TypeId ty) @@ -194,45 +178,11 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - if (FFlag::LuauRefactorTypeVarQuestions) - { - auto utv = get(follow(ty)); - if (!utv) - return false; - - return std::any_of(begin(utv), end(utv), isNil); - } - else - { - std::unordered_set seen; - std::deque queue{ty}; - while (!queue.empty()) - { - TypeId current = follow(queue.front()); - queue.pop_front(); - - if (seen.count(current)) - continue; - - seen.insert(current); - - if (isNil(current)) - return true; - - if (auto u = get(current)) - { - for (TypeId option : u->options) - { - if (isNil(option)) - return true; - - queue.push_back(option); - } - } - } - + auto utv = get(follow(ty)); + if (!utv) return false; - } + + return std::any_of(begin(utv), end(utv), isNil); } bool isTableIntersection(TypeId ty) @@ -263,38 +213,24 @@ std::optional getMetatable(TypeId type) return mtType->metatable; else if (const ClassTypeVar* classType = get(type)) return classType->metatable; - else if (FFlag::LuauRefactorTypeVarQuestions) + else if (isString(type)) { - if (isString(type)) - { - auto ptv = get(getSingletonTypes().stringType); - LUAU_ASSERT(ptv && ptv->metatable); - return ptv->metatable; - } - else - return std::nullopt; - } - else - { - if (const PrimitiveTypeVar* primitiveType = get(type); primitiveType && primitiveType->metatable) - { - LUAU_ASSERT(primitiveType->type == PrimitiveTypeVar::String); - return primitiveType->metatable; - } - else - return std::nullopt; + auto ptv = get(getSingletonTypes().stringType); + LUAU_ASSERT(ptv && ptv->metatable); + return ptv->metatable; } + + return std::nullopt; } const TableTypeVar* getTableType(TypeId type) { - if (FFlag::LuauUnionTagMatchFix) - type = follow(type); + type = follow(type); if (const TableTypeVar* ttv = get(type)) return ttv; else if (const MetatableTypeVar* mtv = get(type)) - return get(FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table); + return get(follow(mtv->table)); else return nullptr; } @@ -311,7 +247,7 @@ const std::string* getName(TypeId type) { if (mtv->syntheticName) return &*mtv->syntheticName; - type = FFlag::LuauUnionTagMatchFix ? follow(mtv->table) : mtv->table; + type = follow(mtv->table); } if (auto ttv = get(type)) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 322f6ebf5..d0eba0135 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -21,9 +21,8 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); -LUAU_FASTFLAG(LuauProperTypeLevels); -LUAU_FASTFLAGVARIABLE(LuauUnionTagMatchFix, false) LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) +LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) namespace Luau { @@ -122,7 +121,7 @@ struct PromoteTypeLevels } }; -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) +static void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) @@ -133,6 +132,7 @@ void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const Typ visitTypeVarOnce(ty, ptl, seen); } +// TODO: use this and make it static. void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside @@ -247,6 +247,48 @@ struct SkipCacheForType bool result = false; }; +bool Widen::isDirty(TypeId ty) +{ + return FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty)); +} + +bool Widen::isDirty(TypePackId) +{ + return false; +} + +TypeId Widen::clean(TypeId ty) +{ + LUAU_ASSERT(isDirty(ty)); + auto stv = FFlag::LuauUseCommittingTxnLog ? log->getMutable(ty) : getMutable(ty); + LUAU_ASSERT(stv); + + if (get(stv)) + return getSingletonTypes().stringType; + else + { + // If this assert trips, it's likely we now have number singletons. + LUAU_ASSERT(get(stv)); + return getSingletonTypes().booleanType; + } +} + +TypePackId Widen::clean(TypePackId) +{ + throw std::runtime_error("Widen attempted to clean a dirty type pack?"); +} + +bool Widen::ignoreChildren(TypeId ty) +{ + // Sometimes we unify ("hi") -> free1 with (free2) -> free3, so don't ignore functions. + // TODO: should we be doing this? we would need to rework how checkCallOverload does the unification. + if (FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))) + return false; + + // We only care about unions. + return !(FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))); +} + static std::optional hasUnificationTooComplex(const ErrorVec& errors) { auto isUnificationTooComplex = [](const TypeError& te) { @@ -263,43 +305,22 @@ static std::optional hasUnificationTooComplex(const ErrorVec& errors) // Used for tagged union matching heuristic, returns first singleton type field static std::optional> getTableMatchTag(TypeId type) { - if (FFlag::LuauUnionTagMatchFix) - { - if (auto ttv = getTableType(type)) - { - for (auto&& [name, prop] : ttv->props) - { - if (auto sing = get(follow(prop.type))) - return {{name, sing}}; - } - } - } - else + if (auto ttv = getTableType(type)) { - type = follow(type); - - if (auto ttv = get(type)) - { - for (auto&& [name, prop] : ttv->props) - { - if (auto sing = get(follow(prop.type))) - return {{name, sing}}; - } - } - else if (auto mttv = get(type)) + for (auto&& [name, prop] : ttv->props) { - return getTableMatchTag(mttv->table); + if (auto sing = get(follow(prop.type))) + return {{name, sing}}; } } return std::nullopt; } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Location& location, Variance variance, UnifierSharedState& sharedState, +Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) - , globalScope(std::move(globalScope)) , log(parentLog) , location(location) , variance(variance) @@ -308,11 +329,10 @@ Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, const Locati LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, ScopePtr globalScope, std::vector>* sharedSeen, const Location& location, +Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) - , globalScope(std::move(globalScope)) , DEPRECATED_log(sharedSeen) , log(parentLog, sharedSeen) , location(location) @@ -435,6 +455,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else if (superFree) { + TypeLevel superLevel = superFree->level; + occursCheck(superTy, subTy); bool occursFailed = false; if (FFlag::LuauUseCommittingTxnLog) @@ -442,8 +464,6 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else occursFailed = bool(get(superTy)); - TypeLevel superLevel = superFree->level; - // Unification can't change the level of a generic. auto subGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy); if (subGeneric && !subGeneric->level.subsumes(superLevel)) @@ -459,20 +479,14 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (FFlag::LuauUseCommittingTxnLog) { promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - log.replace(superTy, BoundTypeVar(subTy)); + log.replace(superTy, BoundTypeVar(widen(subTy))); } else { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - else if (auto subLevel = getMutableLevel(subTy)) - { - if (!subLevel->subsumes(superFree->level)) - *subLevel = superFree->level; - } + promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); DEPRECATED_log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); + *asMutable(superTy) = BoundTypeVar(widen(subTy)); } } @@ -507,16 +521,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } else { - if (FFlag::LuauProperTypeLevels) - promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); - else if (auto superLevel = getMutableLevel(superTy)) - { - if (!superLevel->subsumes(subFree->level)) - { - DEPRECATED_log(superTy); - *superLevel = subFree->level; - } - } + promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); DEPRECATED_log(subTy); *asMutable(subTy) = BoundTypeVar(superTy); @@ -2064,6 +2069,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } } +TypeId Unifier::widen(TypeId ty) +{ + if (!FFlag::LuauWidenIfSupertypeIsFree) + return ty; + + Widen widen{types}; + std::optional result = widen.substitute(ty); + // TODO: what does it mean for substitution to fail to widen? + return result.value_or(ty); +} + TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); @@ -2915,7 +2931,7 @@ void Unifier::tryUnifyWithAny(TypePackId subTy, TypePackId anyTp) std::optional Unifier::findTablePropertyRespectingMeta(TypeId lhsType, Name name) { - return Luau::findTablePropertyRespectingMeta(errors, globalScope, lhsType, name, location); + return Luau::findTablePropertyRespectingMeta(errors, lhsType, name, location); } void Unifier::occursCheck(TypeId needle, TypeId haystack) @@ -3096,9 +3112,9 @@ void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, Typ Unifier Unifier::makeChildUnifier() { if (FFlag::LuauUseCommittingTxnLog) - return Unifier{types, mode, globalScope, log.sharedSeen, location, variance, sharedState, &log}; + return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; else - return Unifier{types, mode, globalScope, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; + return Unifier{types, mode, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; } bool Unifier::isNonstrictMode() const diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 235d6349d..8767daa05 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -12,7 +12,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) -LUAU_FASTFLAGVARIABLE(LuauParseRecoverTypePackEllipsis, false) LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) @@ -2372,11 +2371,11 @@ std::pair, AstArray> Parser::parseG { Location nameLocation = lexer.current().location; AstName name = parseName().name; - if (lexer.current().type == Lexeme::Dot3 || (FFlag::LuauParseRecoverTypePackEllipsis && seenPack)) + if (lexer.current().type == Lexeme::Dot3 || seenPack) { seenPack = true; - if (FFlag::LuauParseRecoverTypePackEllipsis && lexer.current().type != Lexeme::Dot3) + if (lexer.current().type != Lexeme::Dot3) report(lexer.current().location, "Generic types come before generic type packs"); else nextLexeme(); @@ -2414,9 +2413,6 @@ std::pair, AstArray> Parser::parseG } else { - if (!FFlag::LuauParseRecoverTypePackEllipsis && seenPack) - report(lexer.current().location, "Generic types come before generic type packs"); - if (withDefaultValues && lexer.current().type == '=') { seenDefault = true; diff --git a/VM/include/lua.h b/VM/include/lua.h index c5dcef251..af0e28354 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -44,7 +44,7 @@ typedef int (*lua_Continuation)(lua_State* L, int status); ** prototype for memory-allocation functions */ -typedef void* (*lua_Alloc)(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize); +typedef void* (*lua_Alloc)(void* ud, void* ptr, size_t osize, size_t nsize); /* non-return type */ #define l_noret void LUA_NORETURN @@ -178,11 +178,11 @@ LUA_API int lua_pushthread(lua_State* L); /* ** get functions (Lua -> stack) */ -LUA_API void lua_gettable(lua_State* L, int idx); -LUA_API void lua_getfield(lua_State* L, int idx, const char* k); -LUA_API void lua_rawgetfield(lua_State* L, int idx, const char* k); -LUA_API void lua_rawget(lua_State* L, int idx); -LUA_API void lua_rawgeti(lua_State* L, int idx, int n); +LUA_API int lua_gettable(lua_State* L, int idx); +LUA_API int lua_getfield(lua_State* L, int idx, const char* k); +LUA_API int lua_rawgetfield(lua_State* L, int idx, const char* k); +LUA_API int lua_rawget(lua_State* L, int idx); +LUA_API int lua_rawgeti(lua_State* L, int idx, int n); LUA_API void lua_createtable(lua_State* L, int narr, int nrec); LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 29d5f397e..39c76e087 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -659,16 +659,16 @@ int lua_pushthread(lua_State* L) ** get functions (Lua -> stack) */ -void lua_gettable(lua_State* L, int idx) +int lua_gettable(lua_State* L, int idx) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); api_checkvalidindex(L, t); luaV_gettable(L, t, L->top - 1, L->top - 1); - return; + return ttype(L->top - 1); } -void lua_getfield(lua_State* L, int idx, const char* k) +int lua_getfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); @@ -677,10 +677,10 @@ void lua_getfield(lua_State* L, int idx, const char* k) setsvalue(L, &key, luaS_new(L, k)); luaV_gettable(L, t, &key, L->top); api_incr_top(L); - return; + return ttype(L->top - 1); } -void lua_rawgetfield(lua_State* L, int idx, const char* k) +int lua_rawgetfield(lua_State* L, int idx, const char* k) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); @@ -689,26 +689,26 @@ void lua_rawgetfield(lua_State* L, int idx, const char* k) setsvalue(L, &key, luaS_new(L, k)); setobj2s(L, L->top, luaH_getstr(hvalue(t), tsvalue(&key))); api_incr_top(L); - return; + return ttype(L->top - 1); } -void lua_rawget(lua_State* L, int idx) +int lua_rawget(lua_State* L, int idx) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top - 1, luaH_get(hvalue(t), L->top - 1)); - return; + return ttype(L->top - 1); } -void lua_rawgeti(lua_State* L, int idx, int n) +int lua_rawgeti(lua_State* L, int idx, int n) { luaC_checkthreadsleep(L); StkId t = index2addr(L, idx); api_check(L, ttistable(t)); setobj2s(L, L->top, luaH_getnum(hvalue(t), n)); api_incr_top(L); - return; + return ttype(L->top - 1); } void lua_createtable(lua_State* L, int narray, int nrec) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index d87f06618..b5ae496b5 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -151,8 +151,7 @@ l_noret luaD_throw(lua_State* L, int errcode) static void correctstack(lua_State* L, TValue* oldstack) { L->top = (L->top - oldstack) + L->stack; - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (UpVal* up = L->openupval; up != NULL; up = (UpVal*)up->next) + for (UpVal* up = L->openupval; up != NULL; up = up->u.l.threadnext) up->v = (up->v - oldstack) + L->stack; for (CallInfo* ci = L->base_ci; ci <= L->ci; ci++) { diff --git a/VM/src/lfunc.cpp b/VM/src/lfunc.cpp index 582d46277..66447a95a 100644 --- a/VM/src/lfunc.cpp +++ b/VM/src/lfunc.cpp @@ -6,13 +6,10 @@ #include "lmem.h" #include "lgc.h" -LUAU_FASTFLAGVARIABLE(LuauNoDirectUpvalRemoval, false) -LUAU_FASTFLAG(LuauGcPagedSweep) - Proto* luaF_newproto(lua_State* L) { Proto* f = luaM_newgco(L, Proto, sizeof(Proto), L->activememcat); - luaC_link(L, f, LUA_TPROTO); + luaC_init(L, f, LUA_TPROTO); f->k = NULL; f->sizek = 0; f->p = NULL; @@ -40,7 +37,7 @@ Proto* luaF_newproto(lua_State* L) Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) { Closure* c = luaM_newgco(L, Closure, sizeLclosure(nelems), L->activememcat); - luaC_link(L, c, LUA_TFUNCTION); + luaC_init(L, c, LUA_TFUNCTION); c->isC = 0; c->env = e; c->nupvalues = cast_byte(nelems); @@ -55,7 +52,7 @@ Closure* luaF_newLclosure(lua_State* L, int nelems, Table* e, Proto* p) Closure* luaF_newCclosure(lua_State* L, int nelems, Table* e) { Closure* c = luaM_newgco(L, Closure, sizeCclosure(nelems), L->activememcat); - luaC_link(L, c, LUA_TFUNCTION); + luaC_init(L, c, LUA_TFUNCTION); c->isC = 1; c->env = e; c->nupvalues = cast_byte(nelems); @@ -82,8 +79,7 @@ UpVal* luaF_findupval(lua_State* L, StkId level) return p; } - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - pp = (UpVal**)&p->next; + pp = &p->u.l.threadnext; } UpVal* uv = luaM_newgco(L, UpVal, sizeof(UpVal), L->activememcat); /* not found: create a new one */ @@ -94,19 +90,10 @@ UpVal* luaF_findupval(lua_State* L, StkId level) // chain the upvalue in the threads open upvalue list at the proper position UpVal* next = *pp; - - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - uv->next = (GCObject*)next; - - if (FFlag::LuauGcPagedSweep) - { - uv->u.l.threadprev = pp; - if (next) - { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - next->u.l.threadprev = (UpVal**)&uv->next; - } - } + uv->u.l.threadnext = next; + uv->u.l.threadprev = pp; + if (next) + next->u.l.threadprev = &uv->u.l.threadnext; *pp = uv; @@ -125,15 +112,11 @@ void luaF_unlinkupval(UpVal* uv) uv->u.l.next->u.l.prev = uv->u.l.prev; uv->u.l.prev->u.l.next = uv->u.l.next; - if (FFlag::LuauGcPagedSweep) - { - // unlink upvalue from the thread open upvalue list - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and this and the following cast will not be required - *uv->u.l.threadprev = (UpVal*)uv->next; + // unlink upvalue from the thread open upvalue list + *uv->u.l.threadprev = uv->u.l.threadnext; - if (UpVal* next = (UpVal*)uv->next) - next->u.l.threadprev = uv->u.l.threadprev; - } + if (UpVal* next = uv->u.l.threadnext) + next->u.l.threadprev = uv->u.l.threadprev; } void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) @@ -145,34 +128,27 @@ void luaF_freeupval(lua_State* L, UpVal* uv, lua_Page* page) void luaF_close(lua_State* L, StkId level) { - global_State* g = L->global; // TODO: remove with FFlagLuauNoDirectUpvalRemoval + global_State* g = L->global; UpVal* uv; while (L->openupval != NULL && (uv = L->openupval)->v >= level) { GCObject* o = obj2gco(uv); LUAU_ASSERT(!isblack(o) && uv->v != &uv->u.value); - if (!FFlag::LuauGcPagedSweep) - L->openupval = (UpVal*)uv->next; /* remove from `open' list */ + // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue + luaF_unlinkupval(uv); - if (FFlag::LuauGcPagedSweep && isdead(g, o)) + if (isdead(g, o)) { - // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue - luaF_unlinkupval(uv); // close the upvalue without copying the dead data so that luaF_freeupval will not unlink again uv->v = &uv->u.value; } - else if (!FFlag::LuauNoDirectUpvalRemoval && isdead(g, o)) - { - luaF_freeupval(L, uv, NULL); /* free upvalue */ - } else { - // by removing the upvalue from global/thread open upvalue lists, L->openupval will be pointing to the next upvalue - luaF_unlinkupval(uv); setobj(L, &uv->u.value, uv->v); - uv->v = &uv->u.value; /* now current value lives here */ - luaC_linkupval(L, uv); /* link upvalue into `gcroot' list */ + uv->v = &uv->u.value; + // GC state of a new closed upvalue has to be initialized + luaC_initupval(L, uv); } } } diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 724b24b2a..8c3a20296 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -13,8 +13,6 @@ #include -LUAU_FASTFLAGVARIABLE(LuauGcPagedSweep, false) - #define GC_SWEEPMAX 40 #define GC_SWEEPCOST 10 #define GC_SWEEPPAGESTEPCOST 4 @@ -64,7 +62,6 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSatomic: g->gcstats.currcycle.atomictime += seconds; break; - case GCSsweepstring: case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; break; @@ -490,65 +487,6 @@ static void freeobj(lua_State* L, GCObject* o, lua_Page* page) } } -#define sweepwholelist(L, p) sweeplist(L, p, SIZE_MAX) - -static GCObject** sweeplist(lua_State* L, GCObject** p, size_t count) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - GCObject* curr; - global_State* g = L->global; - int deadmask = otherwhite(g); - LUAU_ASSERT(testbit(deadmask, FIXEDBIT)); /* make sure we never sweep fixed objects */ - while ((curr = *p) != NULL && count-- > 0) - { - int alive = (curr->gch.marked ^ WHITEBITS) & deadmask; - if (curr->gch.tt == LUA_TTHREAD) - { - sweepwholelist(L, (GCObject**)&gco2th(curr)->openupval); /* sweep open upvalues */ - - lua_State* th = gco2th(curr); - - if (alive) - { - resetbit(th->stackstate, THREAD_SLEEPINGBIT); - shrinkstack(th); - } - } - if (alive) - { /* not dead? */ - LUAU_ASSERT(!isdead(g, curr)); - makewhite(g, curr); /* make it white (for next cycle) */ - p = &curr->gch.next; - } - else - { /* must erase `curr' */ - LUAU_ASSERT(isdead(g, curr)); - *p = curr->gch.next; - if (curr == g->rootgc) /* is the first element of the list? */ - g->rootgc = curr->gch.next; /* adjust first */ - freeobj(L, curr, NULL); - } - } - - return p; -} - -static void deletelist(lua_State* L, GCObject** p, GCObject* limit) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - GCObject* curr; - while ((curr = *p) != limit) - { - if (curr->gch.tt == LUA_TTHREAD) /* delete open upvalues of each thread */ - deletelist(L, (GCObject**)&gco2th(curr)->openupval, NULL); - - *p = curr->gch.next; - freeobj(L, curr, NULL); - } -} - static void shrinkbuffers(lua_State* L) { global_State* g = L->global; @@ -570,8 +508,6 @@ static void shrinkbuffersfull(lua_State* L) static bool deletegco(void* context, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - // we are in the process of deleting everything // threads with open upvalues will attempt to close them all on removal // but those upvalues might point to stack values that were already deleted @@ -598,32 +534,13 @@ void luaC_freeall(lua_State* L) LUAU_ASSERT(L == g->mainthread); - if (FFlag::LuauGcPagedSweep) - { - luaM_visitgco(L, L, deletegco); - - for (int i = 0; i < g->strt.size; i++) /* free all string lists */ - LUAU_ASSERT(g->strt.hash[i] == NULL); - - LUAU_ASSERT(L->global->strt.nuse == 0); - LUAU_ASSERT(g->strbufgc == NULL); - } - else - { - LUAU_ASSERT(L->next == NULL); /* mainthread is at the end of rootgc list */ - - deletelist(L, &g->rootgc, obj2gco(L)); + luaM_visitgco(L, L, deletegco); - for (int i = 0; i < g->strt.size; i++) /* free all string lists */ - deletelist(L, (GCObject**)&g->strt.hash[i], NULL); + for (int i = 0; i < g->strt.size; i++) /* free all string lists */ + LUAU_ASSERT(g->strt.hash[i] == NULL); - LUAU_ASSERT(L->global->strt.nuse == 0); - deletelist(L, (GCObject**)&g->strbufgc, NULL); - - // unfortunately, when string objects are freed, the string table use count is decremented - // even when the string is a buffer that wasn't placed into the table - L->global->strt.nuse = 0; - } + LUAU_ASSERT(L->global->strt.nuse == 0); + LUAU_ASSERT(g->strbufgc == NULL); } static void markmt(global_State* g) @@ -687,26 +604,13 @@ static size_t atomic(lua_State* L) g->weak = NULL; /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); - g->sweepstrgc = 0; - - if (FFlag::LuauGcPagedSweep) - { - g->sweepgcopage = g->allgcopages; - g->gcstate = GCSsweep; - } - else - { - g->sweepgc = &g->rootgc; - g->gcstate = GCSsweepstring; - } - + g->sweepgcopage = g->allgcopages; + g->gcstate = GCSsweep; return work; } static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; int deadmask = otherwhite(g); @@ -740,8 +644,6 @@ static bool sweepgco(lua_State* L, lua_Page* page, GCObject* gco) // a version of generic luaM_visitpage specialized for the main sweep stage static int sweepgcopage(lua_State* L, lua_Page* page) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - char* start; char* end; int busyBlocks; @@ -819,75 +721,29 @@ static size_t gcstep(lua_State* L, size_t limit) cost = atomic(L); /* finish mark phase */ - if (FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->gcstate == GCSsweep); - else - LUAU_ASSERT(g->gcstate == GCSsweepstring); - break; - } - case GCSsweepstring: - { - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - while (g->sweepstrgc < g->strt.size && cost < limit) - { - sweepwholelist(L, (GCObject**)&g->strt.hash[g->sweepstrgc++]); - - cost += GC_SWEEPCOST; - } - - // nothing more to sweep? - if (g->sweepstrgc >= g->strt.size) - { - // sweep string buffer list and preserve used string count - uint32_t nuse = L->global->strt.nuse; - - sweepwholelist(L, (GCObject**)&g->strbufgc); - - L->global->strt.nuse = nuse; - - g->gcstate = GCSsweep; // end sweep-string phase - } + LUAU_ASSERT(g->gcstate == GCSsweep); break; } case GCSsweep: { - if (FFlag::LuauGcPagedSweep) + while (g->sweepgcopage && cost < limit) { - while (g->sweepgcopage && cost < limit) - { - lua_Page* next = luaM_getnextgcopage(g->sweepgcopage); // page sweep might destroy the page - - int steps = sweepgcopage(L, g->sweepgcopage); + lua_Page* next = luaM_getnextgcopage(g->sweepgcopage); // page sweep might destroy the page - g->sweepgcopage = next; - cost += steps * GC_SWEEPPAGESTEPCOST; - } - - // nothing more to sweep? - if (g->sweepgcopage == NULL) - { - // don't forget to visit main thread - sweepgco(L, NULL, obj2gco(g->mainthread)); + int steps = sweepgcopage(L, g->sweepgcopage); - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ - } + g->sweepgcopage = next; + cost += steps * GC_SWEEPPAGESTEPCOST; } - else - { - while (*g->sweepgc && cost < limit) - { - g->sweepgc = sweeplist(L, g->sweepgc, GC_SWEEPMAX); - cost += GC_SWEEPMAX * GC_SWEEPCOST; - } + // nothing more to sweep? + if (g->sweepgcopage == NULL) + { + // don't forget to visit main thread + sweepgco(L, NULL, obj2gco(g->mainthread)); - if (*g->sweepgc == NULL) - { /* nothing more to sweep? */ - shrinkbuffers(L); - g->gcstate = GCSpause; /* end collection */ - } + shrinkbuffers(L); + g->gcstate = GCSpause; /* end collection */ } break; } @@ -1013,26 +869,18 @@ void luaC_fullgc(lua_State* L) if (g->gcstate <= GCSatomic) { /* reset sweep marks to sweep all elements (returning them to white) */ - g->sweepstrgc = 0; - if (FFlag::LuauGcPagedSweep) - g->sweepgcopage = g->allgcopages; - else - g->sweepgc = &g->rootgc; + g->sweepgcopage = g->allgcopages; /* reset other collector lists */ g->gray = NULL; g->grayagain = NULL; g->weak = NULL; - - if (FFlag::LuauGcPagedSweep) - g->gcstate = GCSsweep; - else - g->gcstate = GCSsweepstring; + g->gcstate = GCSsweep; } - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + LUAU_ASSERT(g->gcstate == GCSsweep); /* finish any pending sweep phase */ while (g->gcstate != GCSpause) { - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + LUAU_ASSERT(g->gcstate == GCSsweep); gcstep(L, SIZE_MAX); } @@ -1120,30 +968,19 @@ void luaC_barrierback(lua_State* L, Table* t) g->grayagain = o; } -void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt) +void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt) { global_State* g = L->global; - if (!FFlag::LuauGcPagedSweep) - { - o->gch.next = g->rootgc; - g->rootgc = o; - } o->gch.marked = luaC_white(g); o->gch.tt = tt; o->gch.memcat = L->activememcat; } -void luaC_linkupval(lua_State* L, UpVal* uv) +void luaC_initupval(lua_State* L, UpVal* uv) { global_State* g = L->global; GCObject* o = obj2gco(uv); - if (!FFlag::LuauGcPagedSweep) - { - o->gch.next = g->rootgc; /* link upvalue into `rootgc' list */ - g->rootgc = o; - } - if (isgray(o)) { if (keepinvariant(g)) @@ -1221,9 +1058,6 @@ const char* luaC_statename(int state) case GCSatomic: return "atomic"; - case GCSsweepstring: - return "sweepstring"; - case GCSsweep: return "sweep"; diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 2acb5d8aa..253e269f9 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -13,9 +13,7 @@ #define GCSpropagate 1 #define GCSpropagateagain 2 #define GCSatomic 3 -// TODO: remove with FFlagLuauGcPagedSweep -#define GCSsweepstring 4 -#define GCSsweep 5 +#define GCSsweep 4 /* ** macro to tell when main invariant (white objects cannot point to black @@ -132,13 +130,13 @@ luaC_wakethread(L); \ } -#define luaC_link(L, o, tt) luaC_linkobj(L, cast_to(GCObject*, (o)), tt) +#define luaC_init(L, o, tt) luaC_initobj(L, cast_to(GCObject*, (o)), tt) LUAI_FUNC void luaC_freeall(lua_State* L); LUAI_FUNC void luaC_step(lua_State* L, bool assist); LUAI_FUNC void luaC_fullgc(lua_State* L); -LUAI_FUNC void luaC_linkobj(lua_State* L, GCObject* o, uint8_t tt); -LUAI_FUNC void luaC_linkupval(lua_State* L, UpVal* uv); +LUAI_FUNC void luaC_initobj(lua_State* L, GCObject* o, uint8_t tt); +LUAI_FUNC void luaC_initupval(lua_State* L, UpVal* uv); LUAI_FUNC void luaC_barrierupval(lua_State* L, GCObject* v); LUAI_FUNC void luaC_barrierf(lua_State* L, GCObject* o, GCObject* v); LUAI_FUNC void luaC_barriertable(lua_State* L, Table* t, GCObject* v); diff --git a/VM/src/lgcdebug.cpp b/VM/src/lgcdebug.cpp index 30242e526..2b38619b0 100644 --- a/VM/src/lgcdebug.cpp +++ b/VM/src/lgcdebug.cpp @@ -13,8 +13,6 @@ #include #include -LUAU_FASTFLAG(LuauGcPagedSweep) - static void validateobjref(global_State* g, GCObject* f, GCObject* t) { LUAU_ASSERT(!isdead(g, t)); @@ -104,8 +102,7 @@ static void validatestack(global_State* g, lua_State* l) if (l->namecall) validateobjref(g, obj2gco(l), obj2gco(l->namecall)); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (UpVal* uv = l->openupval; uv; uv = (UpVal*)uv->next) + for (UpVal* uv = l->openupval; uv; uv = uv->u.l.threadnext) { LUAU_ASSERT(uv->tt == LUA_TUPVAL); LUAU_ASSERT(uv->v != &uv->u.value); @@ -141,7 +138,7 @@ static void validateobj(global_State* g, GCObject* o) /* dead objects can only occur during sweep */ if (isdead(g, o)) { - LUAU_ASSERT(g->gcstate == GCSsweepstring || g->gcstate == GCSsweep); + LUAU_ASSERT(g->gcstate == GCSsweep); return; } @@ -180,18 +177,6 @@ static void validateobj(global_State* g, GCObject* o) } } -static void validatelist(global_State* g, GCObject* o) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - while (o) - { - validateobj(g, o); - - o = o->gch.next; - } -} - static void validategraylist(global_State* g, GCObject* o) { if (!keepinvariant(g)) @@ -224,8 +209,6 @@ static void validategraylist(global_State* g, GCObject* o) static bool validategco(void* context, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - lua_State* L = (lua_State*)context; global_State* g = L->global; @@ -248,20 +231,9 @@ void luaC_validate(lua_State* L) validategraylist(g, g->gray); validategraylist(g, g->grayagain); - if (FFlag::LuauGcPagedSweep) - { - validategco(L, NULL, obj2gco(g->mainthread)); - - luaM_visitgco(L, L, validategco); - } - else - { - for (int i = 0; i < g->strt.size; ++i) - validatelist(g, (GCObject*)(g->strt.hash[i])); + validategco(L, NULL, obj2gco(g->mainthread)); - validatelist(g, g->rootgc); - validatelist(g, (GCObject*)(g->strbufgc)); - } + luaM_visitgco(L, L, validategco); for (UpVal* uv = g->uvhead.u.l.next; uv != &g->uvhead; uv = uv->u.l.next) { @@ -521,30 +493,8 @@ static void dumpobj(FILE* f, GCObject* o) } } -static void dumplist(FILE* f, GCObject* o) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - while (o) - { - dumpref(f, o); - fputc(':', f); - dumpobj(f, o); - fputc(',', f); - fputc('\n', f); - - // thread has additional list containing collectable objects that are not present in rootgc - if (o->gch.tt == LUA_TTHREAD) - dumplist(f, (GCObject*)gco2th(o)->openupval); - - o = o->gch.next; - } -} - static bool dumpgco(void* context, lua_Page* page, GCObject* gco) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - FILE* f = (FILE*)context; dumpref(f, gco); @@ -563,19 +513,9 @@ void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* fprintf(f, "{\"objects\":{\n"); - if (FFlag::LuauGcPagedSweep) - { - dumpgco(f, NULL, obj2gco(g->mainthread)); + dumpgco(f, NULL, obj2gco(g->mainthread)); - luaM_visitgco(L, f, dumpgco); - } - else - { - dumplist(f, g->rootgc); - dumplist(f, (GCObject*)(g->strbufgc)); - for (int i = 0; i < g->strt.size; ++i) - dumplist(f, (GCObject*)(g->strt.hash[i])); - } + luaM_visitgco(L, f, dumpgco); fprintf(f, "\"0\":{\"type\":\"userdata\",\"cat\":0,\"size\":0}\n"); // to avoid issues with trailing , fprintf(f, "},\"roots\":{\n"); diff --git a/VM/src/linit.cpp b/VM/src/linit.cpp index c93f431f1..fd95f5960 100644 --- a/VM/src/linit.cpp +++ b/VM/src/linit.cpp @@ -68,7 +68,7 @@ void luaL_sandboxthread(lua_State* L) lua_setsafeenv(L, LUA_GLOBALSINDEX, true); } -static void* l_alloc(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize) +static void* l_alloc(void* ud, void* ptr, size_t osize, size_t nsize) { (void)ud; (void)osize; diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 19617b8ca..899cb0c0c 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -78,8 +78,6 @@ * allocated pages. */ -LUAU_FASTFLAG(LuauGcPagedSweep) - #ifndef __has_feature #define __has_feature(x) 0 #endif @@ -98,8 +96,10 @@ LUAU_FASTFLAG(LuauGcPagedSweep) * To prevent some of them accidentally growing and us losing memory without realizing it, we're going to lock * the sizes of all critical structures down. */ -#if defined(__APPLE__) && !defined(__MACH__) +#if defined(__APPLE__) #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : gcc32) +#elif defined(__i386__) +#define ABISWITCH(x64, ms32, gcc32) (gcc32) #else // Android somehow uses a similar ABI to MSVC, *not* to iOS... #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : ms32) @@ -114,14 +114,8 @@ static_assert(sizeof(LuaNode) == ABISWITCH(32, 32, 32), "size mismatch for table #endif static_assert(offsetof(TString, data) == ABISWITCH(24, 20, 20), "size mismatch for string header"); -// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(16, 16, 16) -static_assert(offsetof(Udata, data) == ABISWITCH(24, 16, 16), "size mismatch for userdata header"); -// TODO (FFlagLuauGcPagedSweep): this will become ABISWITCH(48, 32, 32) -static_assert(sizeof(Table) == ABISWITCH(56, 36, 36), "size mismatch for table header"); - -// TODO (FFlagLuauGcPagedSweep): new code with old 'next' pointer requires that GCObject start at the same point as TString/UpVal -static_assert(offsetof(GCObject, uv) == 0, "UpVal data must be located at the start of the GCObject"); -static_assert(offsetof(GCObject, ts) == 0, "TString data must be located at the start of the GCObject"); +static_assert(offsetof(Udata, data) == ABISWITCH(16, 16, 12), "size mismatch for userdata header"); +static_assert(sizeof(Table) == ABISWITCH(48, 32, 32), "size mismatch for table header"); const size_t kSizeClasses = LUA_SIZECLASSES; const size_t kMaxSmallSize = 512; @@ -208,53 +202,13 @@ l_noret luaM_toobig(lua_State* L) luaG_runerror(L, "memory allocation error: block too big"); } -static lua_Page* newpageold(lua_State* L, uint8_t sizeClass) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - global_State* g = L->global; - lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, kPageSize); - if (!page) - luaD_throw(L, LUA_ERRMEM); - - int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + kBlockHeader; - int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; - - ASAN_POISON_MEMORY_REGION(page->data, blockSize * blockCount); - - // setup page header - page->prev = NULL; - page->next = NULL; - - page->gcolistprev = NULL; - page->gcolistnext = NULL; - - page->pageSize = kPageSize; - page->blockSize = blockSize; - - // note: we start with the last block in the page and move downward - // either order would work, but that way we don't need to store the block count in the page - // additionally, GC stores objects in singly linked lists, and this way the GC lists end up in increasing pointer order - page->freeList = NULL; - page->freeNext = (blockCount - 1) * blockSize; - page->busyBlocks = 0; - - // prepend a page to page freelist (which is empty because we only ever allocate a new page when it is!) - LUAU_ASSERT(!g->freepages[sizeClass]); - g->freepages[sizeClass] = page; - - return page; -} - static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int blockSize, int blockCount) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; LUAU_ASSERT(pageSize - int(offsetof(lua_Page, data)) >= blockSize * blockCount); - lua_Page* page = (lua_Page*)(*g->frealloc)(L, g->ud, NULL, 0, pageSize); + lua_Page* page = (lua_Page*)(*g->frealloc)(g->ud, NULL, 0, pageSize); if (!page) luaD_throw(L, LUA_ERRMEM); @@ -290,8 +244,6 @@ static lua_Page* newpage(lua_State* L, lua_Page** gcopageset, int pageSize, int static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, uint8_t sizeClass, bool storeMetadata) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - int blockSize = kSizeClassConfig.sizeOfClass[sizeClass] + (storeMetadata ? kBlockHeader : 0); int blockCount = (kPageSize - offsetof(lua_Page, data)) / blockSize; @@ -304,29 +256,8 @@ static lua_Page* newclasspage(lua_State* L, lua_Page** freepageset, lua_Page** g return page; } -static void freepageold(lua_State* L, lua_Page* page, uint8_t sizeClass) -{ - LUAU_ASSERT(!FFlag::LuauGcPagedSweep); - - global_State* g = L->global; - - // remove page from freelist - if (page->next) - page->next->prev = page->prev; - - if (page->prev) - page->prev->next = page->next; - else if (g->freepages[sizeClass] == page) - g->freepages[sizeClass] = page->next; - - // so long - (*g->frealloc)(L, g->ud, page, kPageSize, 0); -} - static void freepage(lua_State* L, lua_Page** gcopageset, lua_Page* page) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; if (gcopageset) @@ -342,13 +273,11 @@ static void freepage(lua_State* L, lua_Page** gcopageset, lua_Page* page) } // so long - (*g->frealloc)(L, g->ud, page, page->pageSize, 0); + (*g->frealloc)(g->ud, page, page->pageSize, 0); } static void freeclasspage(lua_State* L, lua_Page** freepageset, lua_Page** gcopageset, lua_Page* page, uint8_t sizeClass) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - // remove page from freelist if (page->next) page->next->prev = page->prev; @@ -368,12 +297,7 @@ static void* newblock(lua_State* L, int sizeClass) // slow path: no page in the freelist, allocate a new one if (!page) - { - if (FFlag::LuauGcPagedSweep) - page = newclasspage(L, g->freepages, NULL, sizeClass, true); - else - page = newpageold(L, sizeClass); - } + page = newclasspage(L, g->freepages, NULL, sizeClass, true); LUAU_ASSERT(!page->prev); LUAU_ASSERT(page->freeList || page->freeNext >= 0); @@ -416,8 +340,6 @@ static void* newblock(lua_State* L, int sizeClass) static void* newgcoblock(lua_State* L, int sizeClass) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; lua_Page* page = g->freegcopages[sizeClass]; @@ -496,17 +418,11 @@ static void freeblock(lua_State* L, int sizeClass, void* block) // if it's the last block in the page, we don't need the page if (page->busyBlocks == 0) - { - if (FFlag::LuauGcPagedSweep) - freeclasspage(L, g->freepages, NULL, page, sizeClass); - else - freepageold(L, page, sizeClass); - } + freeclasspage(L, g->freepages, NULL, page, sizeClass); } static void freegcoblock(lua_State* L, int sizeClass, void* block, lua_Page* page) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); LUAU_ASSERT(page && page->busyBlocks > 0); LUAU_ASSERT(page->blockSize == kSizeClassConfig.sizeOfClass[sizeClass]); LUAU_ASSERT(block >= page->data && block < (char*)page + page->pageSize); @@ -544,7 +460,7 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) int nclass = sizeclass(nsize); - void* block = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + void* block = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(g->ud, NULL, 0, nsize); if (block == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -556,9 +472,6 @@ void* luaM_new_(lua_State* L, size_t nsize, uint8_t memcat) GCObject* luaM_newgco_(lua_State* L, size_t nsize, uint8_t memcat) { - if (!FFlag::LuauGcPagedSweep) - return (GCObject*)luaM_new_(L, nsize, memcat); - // we need to accommodate space for link for free blocks (freegcolink) LUAU_ASSERT(nsize >= kGCOLinkOffset + sizeof(void*)); @@ -602,7 +515,7 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) if (oclass >= 0) freeblock(L, oclass, block); else - (*g->frealloc)(L, g->ud, block, osize, 0); + (*g->frealloc)(g->ud, block, osize, 0); g->totalbytes -= osize; g->memcatbytes[memcat] -= osize; @@ -610,12 +523,6 @@ void luaM_free_(lua_State* L, void* block, size_t osize, uint8_t memcat) void luaM_freegco_(lua_State* L, GCObject* block, size_t osize, uint8_t memcat, lua_Page* page) { - if (!FFlag::LuauGcPagedSweep) - { - luaM_free_(L, block, osize, memcat); - return; - } - global_State* g = L->global; LUAU_ASSERT((osize == 0) == (block == NULL)); @@ -652,7 +559,7 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 // if either block needs to be allocated using a block allocator, we can't use realloc directly if (nclass >= 0 || oclass >= 0) { - result = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(L, g->ud, NULL, 0, nsize); + result = nclass >= 0 ? newblock(L, nclass) : (*g->frealloc)(g->ud, NULL, 0, nsize); if (result == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); @@ -662,11 +569,11 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 if (oclass >= 0) freeblock(L, oclass, block); else - (*g->frealloc)(L, g->ud, block, osize, 0); + (*g->frealloc)(g->ud, block, osize, 0); } else { - result = (*g->frealloc)(L, g->ud, block, osize, nsize); + result = (*g->frealloc)(g->ud, block, osize, nsize); if (result == NULL && nsize > 0) luaD_throw(L, LUA_ERRMEM); } @@ -679,8 +586,6 @@ void* luaM_realloc_(lua_State* L, void* block, size_t osize, size_t nsize, uint8 void luaM_getpagewalkinfo(lua_Page* page, char** start, char** end, int* busyBlocks, int* blockSize) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - int blockCount = (page->pageSize - offsetof(lua_Page, data)) / page->blockSize; LUAU_ASSERT(page->freeNext >= -page->blockSize && page->freeNext <= (blockCount - 1) * page->blockSize); @@ -700,8 +605,6 @@ lua_Page* luaM_getnextgcopage(lua_Page* page) void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - char* start; char* end; int busyBlocks; @@ -730,8 +633,6 @@ void luaM_visitpage(lua_Page* page, void* context, bool (*visitor)(void* context void luaM_visitgco(lua_State* L, void* context, bool (*visitor)(void* context, lua_Page* page, GCObject* gco)) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; for (lua_Page* curr = g->allgcopages; curr;) diff --git a/VM/src/lmem.h b/VM/src/lmem.h index 1bfe48fa8..007884529 100644 --- a/VM/src/lmem.h +++ b/VM/src/lmem.h @@ -7,11 +7,7 @@ struct lua_Page; union GCObject; -// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_newgco to luaM_new -#define luaM_new(L, t, size, memcat) cast_to(t*, luaM_new_(L, size, memcat)) #define luaM_newgco(L, t, size, memcat) cast_to(t*, luaM_newgco_(L, size, memcat)) -// TODO: remove with FFlagLuauGcPagedSweep and rename luaM_freegco to luaM_free -#define luaM_free(L, p, size, memcat) luaM_free_(L, (p), size, memcat) #define luaM_freegco(L, p, size, memcat, page) luaM_freegco_(L, obj2gco(p), size, memcat, page) #define luaM_arraysize_(n, e) ((cast_to(size_t, (n)) <= SIZE_MAX / (e)) ? (n) * (e) : (luaM_toobig(L), SIZE_MAX)) diff --git a/VM/src/lobject.h b/VM/src/lobject.h index 57ffd82ab..5e02c2ead 100644 --- a/VM/src/lobject.h +++ b/VM/src/lobject.h @@ -15,7 +15,6 @@ typedef union GCObject GCObject; */ // clang-format off #define CommonHeader \ - GCObject* next; /* TODO: remove with FFlagLuauGcPagedSweep */ \ uint8_t tt; uint8_t marked; uint8_t memcat // clang-format on @@ -233,6 +232,8 @@ typedef struct TString int16_t atom; // 2 byte padding + TString* next; // next string in the hash table bucket or the string buffer linked list + unsigned int hash; unsigned int len; @@ -327,7 +328,7 @@ typedef struct UpVal struct UpVal* next; /* thread double linked list (when open) */ - // TODO: when FFlagLuauGcPagedSweep is removed, old outer 'next' value will be placed here + struct UpVal* threadnext; /* note: this is the location of a pointer to this upvalue in the previous element that can be either an UpVal or a lua_State */ struct UpVal** threadprev; } l; diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index d6d127c02..d4f3f0a19 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,7 +10,6 @@ #include "ldo.h" #include "ldebug.h" -LUAU_FASTFLAG(LuauGcPagedSweep) LUAU_FASTFLAGVARIABLE(LuauReduceStackReallocs, false) /* @@ -90,8 +89,6 @@ static void close_state(lua_State* L) global_State* g = L->global; luaF_close(L, L->stack); /* close all upvalues for this thread */ luaC_freeall(L); /* collect all objects */ - if (!FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->rootgc == obj2gco(L)); LUAU_ASSERT(g->strbufgc == NULL); LUAU_ASSERT(g->strt.nuse == 0); luaM_freearray(L, L->global->strt.hash, L->global->strt.size, TString*, 0); @@ -99,22 +96,20 @@ static void close_state(lua_State* L) for (int i = 0; i < LUA_SIZECLASSES; i++) { LUAU_ASSERT(g->freepages[i] == NULL); - if (FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->freegcopages[i] == NULL); + LUAU_ASSERT(g->freegcopages[i] == NULL); } - if (FFlag::LuauGcPagedSweep) - LUAU_ASSERT(g->allgcopages == NULL); + LUAU_ASSERT(g->allgcopages == NULL); LUAU_ASSERT(g->totalbytes == sizeof(LG)); LUAU_ASSERT(g->memcatbytes[0] == sizeof(LG)); for (int i = 1; i < LUA_MEMORY_CATEGORIES; i++) LUAU_ASSERT(g->memcatbytes[i] == 0); - (*g->frealloc)(L, g->ud, L, sizeof(LG), 0); + (*g->frealloc)(g->ud, L, sizeof(LG), 0); } lua_State* luaE_newthread(lua_State* L) { lua_State* L1 = luaM_newgco(L, lua_State, sizeof(lua_State), L->activememcat); - luaC_link(L, L1, LUA_TTHREAD); + luaC_init(L, L1, LUA_TTHREAD); preinit_state(L1, L->global); L1->activememcat = L->activememcat; // inherit the active memory category stack_init(L1, L); /* init stack */ @@ -184,13 +179,11 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) int i; lua_State* L; global_State* g; - void* l = (*f)(NULL, ud, NULL, 0, sizeof(LG)); + void* l = (*f)(ud, NULL, 0, sizeof(LG)); if (l == NULL) return NULL; L = (lua_State*)l; g = &((LG*)L)->g; - if (!FFlag::LuauGcPagedSweep) - L->next = NULL; L->tt = LUA_TTHREAD; L->marked = g->currentwhite = bit2mask(WHITE0BIT, FIXEDBIT); L->memcat = 0; @@ -214,11 +207,6 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) setnilvalue(&g->pseudotemp); setnilvalue(registry(L)); g->gcstate = GCSpause; - if (!FFlag::LuauGcPagedSweep) - g->rootgc = obj2gco(L); - g->sweepstrgc = 0; - if (!FFlag::LuauGcPagedSweep) - g->sweepgc = &g->rootgc; g->gray = NULL; g->grayagain = NULL; g->weak = NULL; @@ -230,14 +218,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) for (i = 0; i < LUA_SIZECLASSES; i++) { g->freepages[i] = NULL; - if (FFlag::LuauGcPagedSweep) - g->freegcopages[i] = NULL; - } - if (FFlag::LuauGcPagedSweep) - { - g->allgcopages = NULL; - g->sweepgcopage = NULL; + g->freegcopages[i] = NULL; } + g->allgcopages = NULL; + g->sweepgcopage = NULL; for (i = 0; i < LUA_T_COUNT; i++) g->mt[i] = NULL; for (i = 0; i < LUA_UTAG_LIMIT; i++) diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 6dd891382..3ee967181 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -142,11 +142,6 @@ typedef struct global_State uint8_t gcstate; /* state of garbage collector */ - int sweepstrgc; /* position of sweep in `strt' */ - // TODO: remove with FFlagLuauGcPagedSweep - GCObject* rootgc; /* list of all collectable objects */ - // TODO: remove with FFlagLuauGcPagedSweep - GCObject** sweepgc; /* position of sweep in `rootgc' */ GCObject* gray; /* list of gray objects */ GCObject* grayagain; /* list of objects to be traversed atomically */ GCObject* weak; /* list of weak tables (to be cleared) */ diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 9bbc43dec..872501468 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -7,8 +7,6 @@ #include -LUAU_FASTFLAG(LuauGcPagedSweep) - unsigned int luaS_hash(const char* str, size_t len) { // Note that this hashing algorithm is replicated in BytecodeBuilder.cpp, BytecodeBuilder::getStringHash @@ -46,8 +44,6 @@ unsigned int luaS_hash(const char* str, size_t len) void luaS_resize(lua_State* L, int newsize) { - if (L->global->gcstate == GCSsweepstring) - return; /* cannot resize during GC traverse */ TString** newhash = luaM_newarray(L, newsize, TString*, 0); stringtable* tb = &L->global->strt; for (int i = 0; i < newsize; i++) @@ -58,13 +54,11 @@ void luaS_resize(lua_State* L, int newsize) TString* p = tb->hash[i]; while (p) { /* for each node in the list */ - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - TString* next = (TString*)p->next; /* save next */ + TString* next = p->next; /* save next */ unsigned int h = p->hash; int h1 = lmod(h, newsize); /* new position */ LUAU_ASSERT(cast_int(h % newsize) == lmod(h, newsize)); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - p->next = (GCObject*)newhash[h1]; /* chain it */ + p->next = newhash[h1]; /* chain it */ newhash[h1] = p; p = next; } @@ -91,8 +85,7 @@ static TString* newlstr(lua_State* L, const char* str, size_t l, unsigned int h) ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, l) : -1; tb = &L->global->strt; h = lmod(h, tb->size); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the case will not be required - ts->next = (GCObject*)tb->hash[h]; /* chain new entry */ + ts->next = tb->hash[h]; /* chain new entry */ tb->hash[h] = ts; tb->nuse++; if (tb->nuse > cast_to(uint32_t, tb->size) && tb->size <= INT_MAX / 2) @@ -104,20 +97,9 @@ static void linkstrbuf(lua_State* L, TString* ts) { global_State* g = L->global; - if (FFlag::LuauGcPagedSweep) - { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - ts->next = (GCObject*)g->strbufgc; - g->strbufgc = ts; - ts->marked = luaC_white(g); - } - else - { - GCObject* o = obj2gco(ts); - o->gch.next = (GCObject*)g->strbufgc; - g->strbufgc = gco2ts(o); - o->gch.marked = luaC_white(g); - } + ts->next = g->strbufgc; + g->strbufgc = ts; + ts->marked = luaC_white(g); } static void unlinkstrbuf(lua_State* L, TString* ts) @@ -130,14 +112,12 @@ static void unlinkstrbuf(lua_State* L, TString* ts) { if (curr == ts) { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - *p = (TString*)curr->next; + *p = curr->next; return; } else { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - p = (TString**)&curr->next; + p = &curr->next; } } @@ -167,8 +147,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) int bucket = lmod(h, tb->size); // search if we already have this string in the hash table - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (TString* el = tb->hash[bucket]; el != NULL; el = (TString*)el->next) + for (TString* el = tb->hash[bucket]; el != NULL; el = el->next) { if (el->len == ts->len && memcmp(el->data, ts->data, ts->len) == 0) { @@ -187,8 +166,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) // Complete string object ts->atom = L->global->cb.useratom ? L->global->cb.useratom(ts->data, ts->len) : -1; - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - ts->next = (GCObject*)tb->hash[bucket]; // chain new entry + ts->next = tb->hash[bucket]; // chain new entry tb->hash[bucket] = ts; tb->nuse++; @@ -201,8 +179,7 @@ TString* luaS_buffinish(lua_State* L, TString* ts) TString* luaS_newlstr(lua_State* L, const char* str, size_t l) { unsigned int h = luaS_hash(str, l); - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - for (TString* el = L->global->strt.hash[lmod(h, L->global->strt.size)]; el != NULL; el = (TString*)el->next) + for (TString* el = L->global->strt.hash[lmod(h, L->global->strt.size)]; el != NULL; el = el->next) { if (el->len == l && (memcmp(str, getstr(el), l) == 0)) { @@ -217,8 +194,6 @@ TString* luaS_newlstr(lua_State* L, const char* str, size_t l) static bool unlinkstr(lua_State* L, TString* ts) { - LUAU_ASSERT(FFlag::LuauGcPagedSweep); - global_State* g = L->global; TString** p = &g->strt.hash[lmod(ts->hash, g->strt.size)]; @@ -227,14 +202,12 @@ static bool unlinkstr(lua_State* L, TString* ts) { if (curr == ts) { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - *p = (TString*)curr->next; + *p = curr->next; return true; } else { - // TODO (FFlagLuauGcPagedSweep): 'next' type will change after removal of the flag and the cast will not be required - p = (TString**)&curr->next; + p = &curr->next; } } @@ -243,20 +216,11 @@ static bool unlinkstr(lua_State* L, TString* ts) void luaS_free(lua_State* L, TString* ts, lua_Page* page) { - if (FFlag::LuauGcPagedSweep) - { - // Unchain from the string table - if (!unlinkstr(L, ts)) - unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands - else - L->global->strt.nuse--; - - luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); - } + // Unchain from the string table + if (!unlinkstr(L, ts)) + unlinkstrbuf(L, ts); // An unlikely scenario when we have a string buffer on our hands else - { L->global->strt.nuse--; - luaM_free(L, ts, sizestring(ts->len), ts->memcat); - } + luaM_freegco(L, ts, sizestring(ts->len), ts->memcat, page); } diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index c57374e0e..0412ea76d 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -425,7 +425,7 @@ static void rehash(lua_State* L, Table* t, const TValue* ek) Table* luaH_new(lua_State* L, int narray, int nhash) { Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); - luaC_link(L, t, LUA_TTABLE); + luaC_init(L, t, LUA_TTABLE); t->metatable = NULL; t->flags = cast_byte(~0); t->array = NULL; @@ -742,7 +742,7 @@ int luaH_getn(Table* t) Table* luaH_clone(lua_State* L, Table* tt) { Table* t = luaM_newgco(L, Table, sizeof(Table), L->activememcat); - luaC_link(L, t, LUA_TTABLE); + luaC_init(L, t, LUA_TTABLE); t->metatable = tt->metatable; t->flags = tt->flags; t->array = NULL; diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 758a9bdb7..0dfac508f 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -12,7 +12,7 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) if (s > INT_MAX - sizeof(Udata)) luaM_toobig(L); Udata* u = luaM_newgco(L, Udata, sizeudata(s), L->activememcat); - luaC_link(L, u, LUA_TUSERDATA); + luaC_init(L, u, LUA_TUSERDATA); u->len = int(s); u->metatable = NULL; LUAU_ASSERT(tag >= 0 && tag <= 255); diff --git a/fuzz/linter.cpp b/fuzz/linter.cpp index 04638d23d..0bdd49f58 100644 --- a/fuzz/linter.cpp +++ b/fuzz/linter.cpp @@ -1,10 +1,12 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include -#include "Luau/TypeInfer.h" -#include "Luau/Linter.h" + #include "Luau/BuiltinDefinitions.h" -#include "Luau/ModuleResolver.h" #include "Luau/Common.h" +#include "Luau/Linter.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) { diff --git a/fuzz/luau.proto b/fuzz/luau.proto index c78fcf31c..190b8c5be 100644 --- a/fuzz/luau.proto +++ b/fuzz/luau.proto @@ -96,11 +96,13 @@ message ExprIndexExpr { } message ExprFunction { - repeated Local args = 1; - required bool vararg = 2; - required StatBlock body = 3; - repeated Type types = 4; - repeated Type rettypes = 5; + repeated Typename generics = 1; + repeated Typename genericpacks = 2; + repeated Local args = 3; + required bool vararg = 4; + required StatBlock body = 5; + repeated Type types = 6; + repeated Type rettypes = 7; } message TableItem { @@ -153,7 +155,10 @@ message ExprBinary { message ExprIfElse { required Expr cond = 1; required Expr then = 2; - required Expr else = 3; + oneof else_oneof { + Expr else = 3; + ExprIfElse elseif = 4; + } } message LValue { @@ -183,6 +188,7 @@ message Stat { StatFunction function = 14; StatLocalFunction local_function = 15; StatTypeAlias type_alias = 16; + StatRequireIntoLocalHelper require_into_local = 17; } } @@ -276,9 +282,16 @@ message StatLocalFunction { } message StatTypeAlias { - required Typename name = 1; - required Type type = 2; - repeated Typename generics = 3; + required bool export = 1; + required Typename name = 2; + required Type type = 3; + repeated Typename generics = 4; + repeated Typename genericpacks = 5; +} + +message StatRequireIntoLocalHelper { + required Local var = 1; + required int32 modulenum = 2; } message Type { @@ -292,6 +305,8 @@ message Type { TypeIntersection intersection = 7; TypeClass class = 8; TypeRef ref = 9; + TypeBoolean boolean = 10; + TypeString string = 11; } } @@ -301,7 +316,8 @@ message TypePrimitive { message TypeLiteral { required Typename name = 1; - repeated Typename generics = 2; + repeated Type generics = 2; + repeated Typename genericpacks = 3; } message TypeTableItem { @@ -320,8 +336,10 @@ message TypeTable { } message TypeFunction { - repeated Type args = 1; - repeated Type rets = 2; + repeated Typename generics = 1; + repeated Typename genericpacks = 2; + repeated Type args = 3; + repeated Type rets = 4; // TODO: vararg? } @@ -347,3 +365,16 @@ message TypeRef { required Local prefix = 1; required Typename index = 2; } + +message TypeBoolean { + required bool val = 1; +} + +message TypeString { + required string val = 1; +} + +message ModuleSet { + optional StatBlock module = 1; + required StatBlock program = 2; +} diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index f407248a5..912fef231 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -2,16 +2,17 @@ #include "src/libfuzzer/libfuzzer_macro.h" #include "luau.pb.h" -#include "Luau/TypeInfer.h" #include "Luau/BuiltinDefinitions.h" -#include "Luau/ModuleResolver.h" -#include "Luau/ModuleResolver.h" -#include "Luau/Compiler.h" -#include "Luau/Linter.h" #include "Luau/BytecodeBuilder.h" #include "Luau/Common.h" +#include "Luau/Compiler.h" +#include "Luau/Frontend.h" +#include "Luau/Linter.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" #include "Luau/ToString.h" #include "Luau/Transpiler.h" +#include "Luau/TypeInfer.h" #include "lua.h" #include "lualib.h" @@ -30,7 +31,7 @@ const bool kFuzzTypes = true; static_assert(!(kFuzzVM && !kFuzzCompiler), "VM requires the compiler!"); -std::string protoprint(const luau::StatBlock& stat, bool types); +std::vector protoprint(const luau::ModuleSet& stat, bool types); LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) @@ -38,6 +39,7 @@ LUAU_FASTINT(LuauCheckRecursionLimit) LUAU_FASTINT(LuauTableTypeMaximumStringifierLength) LUAU_FASTINT(LuauTypeInferIterationLimit) LUAU_FASTINT(LuauTarjanChildLimit) +LUAU_FASTFLAG(DebugLuauFreezeArena) std::chrono::milliseconds kInterruptTimeout(10); std::chrono::time_point interruptDeadline; @@ -135,10 +137,58 @@ int registerTypes(Luau::TypeChecker& env) return 0; } +struct FuzzFileResolver : Luau::FileResolver +{ + std::optional readSource(const Luau::ModuleName& name) override + { + auto it = source.find(name); + if (it == source.end()) + return std::nullopt; + + return Luau::SourceCode{it->second, Luau::SourceCode::Module}; + } + + std::optional resolveModule(const Luau::ModuleInfo* context, Luau::AstExpr* expr) override + { + if (Luau::AstExprGlobal* g = expr->as()) + return Luau::ModuleInfo{g->name.value}; + + return std::nullopt; + } + + std::string getHumanReadableModuleName(const Luau::ModuleName& name) const override + { + return name; + } + + std::optional getEnvironmentForModule(const Luau::ModuleName& name) const override + { + return std::nullopt; + } + + std::unordered_map source; +}; + +struct FuzzConfigResolver : Luau::ConfigResolver +{ + FuzzConfigResolver() + { + defaultConfig.mode = Luau::Mode::Nonstrict; // typecheckTwice option will cover Strict mode + defaultConfig.enabledLint.warningMask = ~0ull; + defaultConfig.parseOptions.captureComments = true; + } + + virtual const Luau::Config& getConfig(const Luau::ModuleName& name) const override + { + return defaultConfig; + } + + Luau::Config defaultConfig; +}; -static std::string debugsource; +static std::vector debugsources; -DEFINE_PROTO_FUZZER(const luau::StatBlock& message) +DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { FInt::LuauTypeInferRecursionLimit.value = 100; FInt::LuauTypeInferTypePackLoopLimit.value = 100; @@ -151,91 +201,90 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) if (strncmp(flag->name, "Luau", 4) == 0) flag->value = true; - Luau::Allocator allocator; - Luau::AstNameTable names(allocator); + FFlag::DebugLuauFreezeArena.value = true; - std::string source = protoprint(message, kFuzzTypes); + std::vector sources = protoprint(message, kFuzzTypes); // stash source in a global for easier crash dump debugging - debugsource = source; - - Luau::ParseResult parseResult = Luau::Parser::parse(source.c_str(), source.size(), names, allocator); - - // "static" here is to accelerate fuzzing process by only creating and populating the type environment once - static Luau::NullModuleResolver moduleResolver; - static Luau::InternalErrorReporter iceHandler; - static Luau::TypeChecker sharedEnv(&moduleResolver, &iceHandler); - static int once = registerTypes(sharedEnv); - (void)once; - static int once2 = (Luau::freeze(sharedEnv.globalTypes), 0); - (void)once2; - - iceHandler.onInternalError = [](const char* error) { - printf("ICE: %s\n", error); - LUAU_ASSERT(!"ICE"); - }; + debugsources = sources; static bool debug = getenv("LUAU_DEBUG") != 0; if (debug) { - fprintf(stdout, "--\n%s\n", source.c_str()); + for (std::string& source : sources) + fprintf(stdout, "--\n%s\n", source.c_str()); fflush(stdout); } - std::string bytecode; + // parse all sources + std::vector> parseAllocators; + std::vector> parseNameTables; - // compile - if (kFuzzCompiler && parseResult.errors.empty()) - { - Luau::CompileOptions compileOptions; + Luau::ParseOptions parseOptions; + parseOptions.captureComments = true; - try - { - Luau::BytecodeBuilder bcb; - Luau::compileOrThrow(bcb, parseResult.root, names, compileOptions); - bytecode = bcb.getBytecode(); - } - catch (const Luau::CompileError&) - { - // not all valid ASTs can be compiled due to limits on number of registers - } - } + std::vector parseResults; - // typecheck - if (kFuzzTypeck && parseResult.root) + for (std::string& source : sources) { - Luau::SourceModule sourceModule; - sourceModule.root = parseResult.root; - sourceModule.mode = Luau::Mode::Nonstrict; - - Luau::TypeChecker typeck(&moduleResolver, &iceHandler); - typeck.globalScope = sharedEnv.globalScope; + parseAllocators.push_back(std::make_unique()); + parseNameTables.push_back(std::make_unique(*parseAllocators.back())); - Luau::ModulePtr module = nullptr; + parseResults.push_back(Luau::Parser::parse(source.c_str(), source.size(), *parseNameTables.back(), *parseAllocators.back(), parseOptions)); + } - try - { - module = typeck.check(sourceModule, Luau::Mode::Nonstrict); - } - catch (std::exception&) + // typecheck all sources + if (kFuzzTypeck) + { + static FuzzFileResolver fileResolver; + static Luau::NullConfigResolver configResolver; + static Luau::FrontendOptions options{true, true}; + static Luau::Frontend frontend(&fileResolver, &configResolver, options); + + static int once = registerTypes(frontend.typeChecker); + (void)once; + static int once2 = (Luau::freeze(frontend.typeChecker.globalTypes), 0); + (void)once2; + + frontend.iceHandler.onInternalError = [](const char* error) { + printf("ICE: %s\n", error); + LUAU_ASSERT(!"ICE"); + }; + + // restart + frontend.clear(); + fileResolver.source.clear(); + + // load sources + for (size_t i = 0; i < sources.size(); i++) { - // This catches internal errors that the type checker currently (unfortunately) throws in some cases + std::string name = "module" + std::to_string(i); + fileResolver.source[name] = sources[i]; } - // lint (note that we need access to types so we need to do this with typeck in scope) - if (kFuzzLinter) + // check sources + for (size_t i = 0; i < sources.size(); i++) { - Luau::LintOptions lintOptions = {~0u}; - Luau::lint(parseResult.root, names, sharedEnv.globalScope, module.get(), {}, lintOptions); + std::string name = "module" + std::to_string(i); + + try + { + Luau::CheckResult result = frontend.check(name, std::nullopt); + + // lint (note that we need access to types so we need to do this with typeck in scope) + if (kFuzzLinter && result.errors.empty()) + frontend.lint(name, std::nullopt); + } + catch (std::exception&) + { + // This catches internal errors that the type checker currently (unfortunately) throws in some cases + } } - } - // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down - // note: it's important for typeck to be destroyed at this point! - if (kFuzzTypeck) - { - for (auto& p : sharedEnv.globalScope->bindings) + // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down + // note: it's important for typeck to be destroyed at this point! + for (auto& p : frontend.typeChecker.globalScope->bindings) { Luau::ToStringOptions opts; opts.exhaustive = true; @@ -246,12 +295,44 @@ DEFINE_PROTO_FUZZER(const luau::StatBlock& message) } } - if (kFuzzTranspile && parseResult.root) + if (kFuzzTranspile) { - transpileWithTypes(*parseResult.root); + for (Luau::ParseResult& parseResult : parseResults) + { + if (parseResult.root) + transpileWithTypes(*parseResult.root); + } + } + + std::string bytecode; + + // compile + if (kFuzzCompiler) + { + for (size_t i = 0; i < parseResults.size(); i++) + { + Luau::ParseResult& parseResult = parseResults[i]; + Luau::AstNameTable& parseNameTable = *parseNameTables[i]; + + if (parseResult.errors.empty()) + { + Luau::CompileOptions compileOptions; + + try + { + Luau::BytecodeBuilder bcb; + Luau::compileOrThrow(bcb, parseResult.root, parseNameTable, compileOptions); + bytecode = bcb.getBytecode(); + } + catch (const Luau::CompileError&) + { + // not all valid ASTs can be compiled due to limits on number of registers + } + } + } } - // run resulting bytecode + // run resulting bytecode (from last successfully compiler module) if (kFuzzVM && bytecode.size()) { static lua_State* globalState = createGlobalState(); diff --git a/fuzz/protoprint.cpp b/fuzz/protoprint.cpp index e61b69365..66a89f243 100644 --- a/fuzz/protoprint.cpp +++ b/fuzz/protoprint.cpp @@ -208,6 +208,35 @@ struct ProtoToLuau source += std::to_string(name.index() & 0xff); } + template + void genericidents(const T& node) + { + if (node.generics_size() || node.genericpacks_size()) + { + source += '<'; + bool first = true; + + for (size_t i = 0; i < node.generics_size(); ++i) + { + if (!first) + source += ','; + first = false; + ident(node.generics(i)); + } + + for (size_t i = 0; i < node.genericpacks_size(); ++i) + { + if (!first) + source += ','; + first = false; + ident(node.genericpacks(i)); + source += "..."; + } + + source += '>'; + } + } + void print(const luau::Expr& expr) { if (expr.has_group()) @@ -240,6 +269,8 @@ struct ProtoToLuau print(expr.unary()); else if (expr.has_binary()) print(expr.binary()); + else if (expr.has_ifelse()) + print(expr.ifelse()); else source += "_"; } @@ -350,6 +381,7 @@ struct ProtoToLuau void function(const luau::ExprFunction& expr) { + genericidents(expr); source += "("; for (int i = 0; i < expr.args_size(); ++i) { @@ -478,12 +510,21 @@ struct ProtoToLuau void print(const luau::ExprIfElse& expr) { - source += " if "; + source += "if "; print(expr.cond()); source += " then "; print(expr.then()); - source += " else "; - print(expr.else_()); + + if (expr.has_else_()) + { + source += " else "; + print(expr.else_()); + } + else if (expr.has_elseif()) + { + source += " else"; + print(expr.elseif()); + } } void print(const luau::LValue& expr) @@ -534,6 +575,8 @@ struct ProtoToLuau print(stat.local_function()); else if (stat.has_type_alias()) print(stat.type_alias()); + else if (stat.has_require_into_local()) + print(stat.require_into_local()); else source += "do end\n"; } @@ -804,26 +847,24 @@ struct ProtoToLuau void print(const luau::StatTypeAlias& stat) { + if (stat.export_()) + source += "export "; + source += "type "; ident(stat.name()); - - if (stat.generics_size()) - { - source += '<'; - for (size_t i = 0; i < stat.generics_size(); ++i) - { - if (i != 0) - source += ','; - ident(stat.generics(i)); - } - source += '>'; - } - + genericidents(stat); source += " = "; print(stat.type()); source += '\n'; } + void print(const luau::StatRequireIntoLocalHelper& stat) + { + source += "local "; + print(stat.var()); + source += " = require(module" + std::to_string(stat.modulenum() % 2) + ")\n"; + } + void print(const luau::Type& type) { if (type.has_primitive()) @@ -844,6 +885,10 @@ struct ProtoToLuau print(type.class_()); else if (type.has_ref()) print(type.ref()); + else if (type.has_boolean()) + print(type.boolean()); + else if (type.has_string()) + print(type.string()); else source += "any"; } @@ -858,15 +903,28 @@ struct ProtoToLuau { ident(type.name()); - if (type.generics_size()) + if (type.generics_size() || type.genericpacks_size()) { source += '<'; + bool first = true; + for (size_t i = 0; i < type.generics_size(); ++i) { - if (i != 0) + if (!first) source += ','; - ident(type.generics(i)); + first = false; + print(type.generics(i)); } + + for (size_t i = 0; i < type.genericpacks_size(); ++i) + { + if (!first) + source += ','; + first = false; + ident(type.genericpacks(i)); + source += "..."; + } + source += '>'; } } @@ -893,6 +951,7 @@ struct ProtoToLuau void print(const luau::TypeFunction& type) { + genericidents(type); source += '('; for (size_t i = 0; i < type.args_size(); ++i) { @@ -950,12 +1009,38 @@ struct ProtoToLuau source += '.'; ident(type.index()); } + + void print(const luau::TypeBoolean& type) + { + source += type.val() ? "true" : "false"; + } + + void print(const luau::TypeString& type) + { + source += '"'; + for (char ch : type.val()) + if (isgraph(ch)) + source += ch; + source += '"'; + } }; -std::string protoprint(const luau::StatBlock& stat, bool types) +std::vector protoprint(const luau::ModuleSet& stat, bool types) { + std::vector result; + + if (stat.has_module()) + { + ProtoToLuau printer; + printer.types = types; + printer.print(stat.module()); + result.push_back(printer.source); + } + ProtoToLuau printer; printer.types = types; - printer.print(stat); - return printer.source; + printer.print(stat.program()); + result.push_back(printer.source); + + return result; } diff --git a/fuzz/prototest.cpp b/fuzz/prototest.cpp index 804e708a9..ccaa1971b 100644 --- a/fuzz/prototest.cpp +++ b/fuzz/prototest.cpp @@ -2,11 +2,15 @@ #include "src/libfuzzer/libfuzzer_macro.h" #include "luau.pb.h" -std::string protoprint(const luau::StatBlock& stat, bool types); +std::vector protoprint(const luau::ModuleSet& stat, bool types); -DEFINE_PROTO_FUZZER(const luau::StatBlock& message) +DEFINE_PROTO_FUZZER(const luau::ModuleSet& message) { - std::string source = protoprint(message, true); + std::vector sources = protoprint(message, true); - printf("%s\n", source.c_str()); + for (size_t i = 0; i < sources.size(); i++) + { + printf("Module 'l%d':\n", int(i)); + printf("%s\n", sources[i].c_str()); + } } diff --git a/fuzz/typeck.cpp b/fuzz/typeck.cpp index 5020c7710..3905cc191 100644 --- a/fuzz/typeck.cpp +++ b/fuzz/typeck.cpp @@ -1,9 +1,11 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include -#include "Luau/TypeInfer.h" + #include "Luau/BuiltinDefinitions.h" -#include "Luau/ModuleResolver.h" #include "Luau/Common.h" +#include "Luau/ModuleResolver.h" +#include "Luau/Parser.h" +#include "Luau/TypeInfer.h" LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTINT(LuauTypeInferTypePackLoopLimit) diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 1978a0d31..6aadef329 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2752,7 +2752,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") ScopedFastFlag sffs[] = { {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauRefactorTypeVarQuestions", true}, }; check(R"( diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index b09c1efb9..eb6ca749a 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -712,6 +712,47 @@ TEST_CASE("Reference") CHECK(dtorhits == 2); } +TEST_CASE("ApiTables") +{ + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + lua_newtable(L); + lua_pushnumber(L, 123.0); + lua_setfield(L, -2, "key"); + lua_pushstring(L, "test"); + lua_rawseti(L, -2, 5); + + // lua_gettable + lua_pushstring(L, "key"); + CHECK(lua_gettable(L, -2) == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_getfield + CHECK(lua_getfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawgetfield + CHECK(lua_rawgetfield(L, -1, "key") == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawget + lua_pushstring(L, "key"); + CHECK(lua_rawget(L, -2) == LUA_TNUMBER); + CHECK(lua_tonumber(L, -1) == 123.0); + lua_pop(L, 1); + + // lua_rawgeti + CHECK(lua_rawgeti(L, -1, 5) == LUA_TSTRING); + CHECK(strcmp(lua_tostring(L, -1), "test") == 0); + lua_pop(L, 1); + + lua_pop(L, 1); +} + TEST_CASE("ApiFunctionCalls") { StateRef globalState = runConformance("apicalls.lua"); @@ -796,7 +837,7 @@ TEST_CASE("ExceptionObject") return ExceptionResult{false, ""}; }; - auto reallocFunc = [](lua_State* L, void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { + auto reallocFunc = [](void* /*ud*/, void* ptr, size_t /*osize*/, size_t nsize) -> void* { if (nsize == 0) { free(ptr); @@ -923,4 +964,53 @@ TEST_CASE("StringConversion") runConformance("strconv.lua"); } +TEST_CASE("GCDump") +{ + // internal function, declared in lgc.h - not exposed via lua.h + extern void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // push various objects on stack to cover different paths + lua_createtable(L, 1, 2); + lua_pushstring(L, "value"); + lua_setfield(L, -2, "key"); + + lua_pushinteger(L, 42); + lua_rawseti(L, -2, 1000); + + lua_pushinteger(L, 42); + lua_rawseti(L, -2, 1); + + lua_pushvalue(L, -1); + lua_setmetatable(L, -2); + + lua_newuserdata(L, 42); + lua_pushvalue(L, -2); + lua_setmetatable(L, -2); + + lua_pushinteger(L, 1); + lua_pushcclosure(L, lua_silence, "test", 1); + + lua_State* CL = lua_newthread(L); + + lua_pushstring(CL, "local x x = {} local function f() x[1] = math.abs(42) end function foo() coroutine.yield() end foo() return f"); + lua_loadstring(CL); + lua_resume(CL, nullptr, 0); + +#ifdef _WIN32 + const char* path = "NUL"; +#else + const char* path = "/dev/null"; +#endif + + FILE* f = fopen(path, "w"); + REQUIRE(f); + + luaC_dump(L, f, nullptr); + + fclose(f); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index d4b973607..ab19cea30 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1594,4 +1594,17 @@ TEST_CASE_FIXTURE(Fixture, "WrongCommentMuteSelf") REQUIRE_EQ(result.warnings.size(), 0); // --!nolint disables WrongComment lint :) } +TEST_CASE_FIXTURE(Fixture, "DuplicateConditionsIfStatAndExpr") +{ + LintResult result = lint(R"( +if if 1 then 2 else 3 then +elseif if 1 then 2 else 3 then +elseif if 0 then 5 else 4 then +end +)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Condition has already been checked on line 2"); +} + TEST_SUITE_END(); diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 0d4c088dd..77e49ce3c 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -2575,7 +2575,6 @@ do end TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauParseRecoverTypePackEllipsis{"LuauParseRecoverTypePackEllipsis", true}; ParseResult result = tryParse(R"( type Y = (T...) -> U... diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index ac5be859b..332aba9e4 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -651,4 +651,19 @@ local a: Packed CHECK_EQ(code, transpile(code, {}, true).code); } + +TEST_CASE_FIXTURE(Fixture, "transpile_singleton_types") +{ + ScopedFastFlag luauParseSingletonTypes{"LuauParseSingletonTypes", true}; + + std::string code = R"( +type t1 = 'hello' +type t2 = true +type t3 = '' +type t4 = false + )"; + + CHECK_EQ(code, transpile(code, {}, true).code); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index d677e28d8..26881b5cc 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -175,6 +175,8 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_property_guarante TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_depth") { + ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; + CheckResult result = check(R"( type A = {x: {y: {z: {thing: string}}}} type B = {x: {y: {z: {thing: string}}}} @@ -184,7 +186,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_works_at_arbitrary_dep )"); LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(*typeChecker.stringType, *requireType("r")); + CHECK_EQ("string & string", toString(requireType("r"))); } TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_mixed_types") @@ -218,7 +220,7 @@ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_part_missing_ TEST_CASE_FIXTURE(Fixture, "index_on_an_intersection_type_with_one_property_of_type_any") { CheckResult result = check(R"( - type A = {x: number} + type A = {y: number} type B = {x: any} local t: A & B diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 48e6be6a4..bff8926c6 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1115,6 +1115,22 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_ LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "refine_a_property_not_to_be_nil_through_an_intersection_table") +{ + ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; + + CheckResult result = check(R"( + type T = {} & {f: ((string) -> string)?} + local function f(t: T, x) + if t.f then + t.f(x) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { ScopedFastFlag sff[] = { diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 9021700dc..856549bde 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -439,4 +439,128 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauSingletonTypes", true}, + {"LuauEqConstraint", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWeakEqConstraint", false}, + }; + + CheckResult result = check(R"( + local function foo(f, x) + if x == "hi" then + f(x) + f("foo") + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 18}))); + // should be ((string) -> a..., string) -> () but needs lower bounds calculation + CHECK_EQ("((string) -> (b...), a) -> ()", toString(requireType("foo"))); +} + +// TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") +// { +// ScopedFastFlag sff[]{ +// {"LuauParseSingletonTypes", true}, +// {"LuauSingletonTypes", true}, +// {"LuauDiscriminableUnions2", true}, +// {"LuauEqConstraint", true}, +// {"LuauWidenIfSupertypeIsFree", true}, +// {"LuauWeakEqConstraint", false}, +// }; + +// CheckResult result = check(R"( +// local function foo(f, x): "hello"? -- anyone there? +// return if x == "hi" +// then f(x) +// else nil +// end +// )"); + +// LUAU_REQUIRE_NO_ERRORS(result); + +// CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); +// CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); +// } + +TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauWidenIfSupertypeIsFree", true}, + }; + + CheckResult result = check(R"( + local foo: "foo" = "foo" + local copy = foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("string", toString(requireType("copy"))); +} + +TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree", true}, + }; + + CheckResult result = check(R"( + type Cat = {tag: "Cat", meows: boolean} + type Dog = {tag: "Dog", barks: boolean} + type Animal = Cat | Dog + + local function f(tag: "Cat" | "Dog"): Animal? + if tag == "Cat" then + local result = {tag = tag, meows = true} + return result + elseif tag == "Dog" then + local result = {tag = tag, barks = true} + return result + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauWidenIfSupertypeIsFree", true}, + }; + + CheckResult result = check(R"( + local function foo(t, x) + if x == "hi" or x == "bye" then + table.insert(t, x) + end + + return t + end + + local t = foo({}, "hi") + table.insert(t, "totally_unrelated_type" :: "totally_unrelated_type") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("{string}", toString(requireType("t"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 6bcd4b99a..aa949789c 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2239,4 +2239,22 @@ TEST_CASE_FIXTURE(Fixture, "give_up_after_one_metatable_index_look_up") CHECK_EQ("Type 't2' does not have key 'x'", toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "confusing_indexing") +{ + ScopedFastFlag sff{"LuauDoNotTryToReduce", true}; + + CheckResult result = check(R"( + type T = {} & {p: number | string} + local function f(t: T) + return t.p + end + + local foo = f({p = "string"}) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number | string", toString(requireType("foo"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 323585712..f44d9fd83 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5127,8 +5127,6 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") { - ScopedFastFlag noSealedTypeMod{"LuauNoSealedTypeMod", true}; - fileResolver.source["game/A"] = R"( export type Type = { unrelated: boolean } return {} @@ -5190,7 +5188,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; @@ -5210,7 +5207,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; @@ -5230,7 +5226,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; @@ -5250,7 +5245,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauRefactorTypeVarQuestions", true}, {"LuauSingletonTypes", true}, }; diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index 4669ea8eb..c9bf51032 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -19,7 +19,7 @@ struct TryUnifyFixture : Fixture ScopePtr globalScope{new Scope{arena.addTypePack({TypeId{}})}}; InternalErrorReporter iceHandler; UnifierSharedState unifierState{&iceHandler}; - Unifier state{&arena, Mode::Strict, globalScope, Location{}, Variance::Covariant, unifierState}; + Unifier state{&arena, Mode::Strict, Location{}, Variance::Covariant, unifierState}; }; TEST_SUITE_BEGIN("TryUnifyTests"); @@ -261,8 +261,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") { - ScopedFastFlag luauUnionTagMatchFix{"LuauUnionTagMatchFix", true}; - TypeVar redirect{FreeTypeVar{TypeLevel{}}}; TypeVar table{TableTypeVar{}}; TypeVar metatable{MetatableTypeVar{&redirect, &table}}; diff --git a/tests/TypeVar.test.cpp b/tests/TypeVar.test.cpp index e43161fa3..fd5f4dbcd 100644 --- a/tests/TypeVar.test.cpp +++ b/tests/TypeVar.test.cpp @@ -361,16 +361,12 @@ local b: (T, T, T) -> T TEST_CASE("isString_on_string_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; CHECK(isString(&helloString)); } TEST_CASE("isString_on_unions_of_various_string_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; TypeVar union_{UnionTypeVar{{&helloString, &byeString}}}; @@ -380,8 +376,6 @@ TEST_CASE("isString_on_unions_of_various_string_singletons") TEST_CASE("proof_that_isString_uses_all_of") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar helloString{SingletonTypeVar{StringSingleton{"hello"}}}; TypeVar byeString{SingletonTypeVar{StringSingleton{"bye"}}}; TypeVar booleanType{PrimitiveTypeVar{PrimitiveTypeVar::Boolean}}; @@ -392,16 +386,12 @@ TEST_CASE("proof_that_isString_uses_all_of") TEST_CASE("isBoolean_on_boolean_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; CHECK(isBoolean(&trueBool)); } TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; TypeVar union_{UnionTypeVar{{&trueBool, &falseBool}}}; @@ -411,8 +401,6 @@ TEST_CASE("isBoolean_on_unions_of_true_or_false_singletons") TEST_CASE("proof_that_isBoolean_uses_all_of") { - ScopedFastFlag sff{"LuauRefactorTypeVarQuestions", true}; - TypeVar trueBool{SingletonTypeVar{BooleanSingleton{true}}}; TypeVar falseBool{SingletonTypeVar{BooleanSingleton{false}}}; TypeVar stringType{PrimitiveTypeVar{PrimitiveTypeVar::String}}; diff --git a/tools/natvis/VM.natvis b/tools/natvis/VM.natvis index 9924e194f..ccc7e3905 100644 --- a/tools/natvis/VM.natvis +++ b/tools/natvis/VM.natvis @@ -183,13 +183,12 @@ openupval - u.l.next + u.l.threadnext this - l_gt - env + gt userdata From 9bfecab5baf796c0b358a88fbc8ca8d70af2da12 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 4 Mar 2022 08:19:20 -0800 Subject: [PATCH 28/32] Sync to upstream/release/517 --- Analysis/include/Luau/TxnLog.h | 32 ++- Analysis/include/Luau/TypeInfer.h | 18 -- Analysis/include/Luau/TypeVar.h | 5 +- Analysis/include/Luau/Unifier.h | 2 +- Analysis/src/Autocomplete.cpp | 62 ++--- Analysis/src/BuiltinDefinitions.cpp | 13 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 1 - Analysis/src/JsonEncoder.cpp | 44 +-- Analysis/src/Linter.cpp | 157 ++++++++++- Analysis/src/Module.cpp | 5 +- Analysis/src/ToString.cpp | 119 +------- Analysis/src/Transpiler.cpp | 63 ++--- Analysis/src/TxnLog.cpp | 80 +++++- Analysis/src/TypeInfer.cpp | 291 ++++++++------------ Analysis/src/TypeVar.cpp | 27 +- Analysis/src/Unifier.cpp | 135 +++++++-- Ast/src/Parser.cpp | 3 +- CLI/Repl.cpp | 35 +-- VM/include/lua.h | 3 + VM/src/lapi.cpp | 52 +++- VM/src/laux.cpp | 46 +--- VM/src/ldebug.cpp | 5 + VM/src/lgc.cpp | 74 ++++- VM/src/lgc.h | 7 - VM/src/lnumprint.cpp | 9 - VM/src/lnumutils.h | 1 - VM/src/lstate.h | 23 +- VM/src/ltable.cpp | 3 +- VM/src/ltablib.cpp | 21 ++ bench/gc/test_GC_Boehm_Trees.lua | 3 + bench/gc/test_GC_Tree_Pruning_Eager.lua | 2 +- bench/gc/test_GC_Tree_Pruning_Gen.lua | 2 +- bench/gc/test_GC_Tree_Pruning_Lazy.lua | 2 +- extern/isocline/include/isocline.h | 2 +- fuzz/number.cpp | 4 - fuzz/proto.cpp | 2 +- tests/Autocomplete.test.cpp | 10 +- tests/Conformance.test.cpp | 6 +- tests/Linter.test.cpp | 107 +++++++ tests/Parser.test.cpp | 6 - tests/ToDot.test.cpp | 2 - tests/ToString.test.cpp | 2 - tests/Transpiler.test.cpp | 3 - tests/TypeInfer.aliases.test.cpp | 3 +- tests/TypeInfer.builtins.test.cpp | 27 ++ tests/TypeInfer.generics.test.cpp | 89 ++++++ tests/TypeInfer.provisional.test.cpp | 95 +------ tests/TypeInfer.refinements.test.cpp | 11 +- tests/TypeInfer.singletons.test.cpp | 50 ++-- tests/TypeInfer.tables.test.cpp | 81 +++++- tests/TypeInfer.test.cpp | 71 +++++ tests/TypeInfer.tryUnify.test.cpp | 7 +- tests/TypeInfer.typePacks.cpp | 39 --- tests/TypeInfer.unionTypes.test.cpp | 31 ++- tests/conformance/nextvar.lua | 38 +++ 55 files changed, 1262 insertions(+), 769 deletions(-) diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index f238e258a..f81053839 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -12,6 +12,8 @@ LUAU_FASTFLAG(LuauShareTxnSeen); namespace Luau { +using TypeOrPackId = const void*; + // Log of where what TypeIds we are rebinding and what they used to be // Remove with LuauUseCommitTxnLog struct DEPRECATED_TxnLog @@ -23,7 +25,7 @@ struct DEPRECATED_TxnLog { } - explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) + explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) : originalSeenSize(sharedSeen->size()) , ownedSeen() , sharedSeen(sharedSeen) @@ -48,15 +50,23 @@ struct DEPRECATED_TxnLog void pushSeen(TypeId lhs, TypeId rhs); void popSeen(TypeId lhs, TypeId rhs); + bool haveSeen(TypePackId lhs, TypePackId rhs); + void pushSeen(TypePackId lhs, TypePackId rhs); + void popSeen(TypePackId lhs, TypePackId rhs); + private: std::vector> typeVarChanges; std::vector> typePackChanges; std::vector>> tableChanges; size_t originalSeenSize; + bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs); + void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs); + void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); + public: - std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic - std::vector>* sharedSeen; // shared with all the descendent logs + std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic + std::vector>* sharedSeen; // shared with all the descendent logs }; // Pending state for a TypeVar. Generated by a TxnLog and committed via @@ -127,12 +137,12 @@ struct TxnLog } } - explicit TxnLog(std::vector>* sharedSeen) + explicit TxnLog(std::vector>* sharedSeen) : sharedSeen(sharedSeen) { } - TxnLog(TxnLog* parent, std::vector>* sharedSeen) + TxnLog(TxnLog* parent, std::vector>* sharedSeen) : parent(parent) , sharedSeen(sharedSeen) { @@ -173,6 +183,10 @@ struct TxnLog void pushSeen(TypeId lhs, TypeId rhs); void popSeen(TypeId lhs, TypeId rhs); + bool haveSeen(TypePackId lhs, TypePackId rhs) const; + void pushSeen(TypePackId lhs, TypePackId rhs); + void popSeen(TypePackId lhs, TypePackId rhs); + // Queues a type for modification. The original type will not change until commit // is called. Use pending to get the pending state. // @@ -316,12 +330,16 @@ struct TxnLog // TxnLogs; use sharedSeen instead. This field exists because in the tree // of TxnLogs, the root must own its seen set. In all descendant TxnLogs, // this is an empty vector. - std::vector> ownedSeen; + std::vector> ownedSeen; + + bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const; + void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs); + void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); public: // Used to avoid infinite recursion when types are cyclic. // Shared with all the descendent TxnLogs. - std::vector>* sharedSeen; + std::vector>* sharedSeen; }; } // namespace Luau diff --git a/Analysis/include/Luau/TypeInfer.h b/Analysis/include/Luau/TypeInfer.h index 2440c810b..839043cc0 100644 --- a/Analysis/include/Luau/TypeInfer.h +++ b/Analysis/include/Luau/TypeInfer.h @@ -73,24 +73,6 @@ struct Instantiation : Substitution TypePackId clean(TypePackId tp) override; }; -// A substitution which replaces free types by generic types. -struct Quantification : Substitution -{ - Quantification(TypeArena* arena, TypeLevel level) - : Substitution(TxnLog::empty(), arena) - , level(level) - { - } - - TypeLevel level; - std::vector generics; - std::vector genericPacks; - bool isDirty(TypeId ty) override; - bool isDirty(TypePackId tp) override; - TypeId clean(TypeId ty) override; - TypePackId clean(TypePackId tp) override; -}; - // A substitution which replaces free types by any struct Anyification : Substitution { diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 8d1a9fa6c..29578dcd9 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -298,7 +298,7 @@ struct TableTypeVar TableTypeVar() = default; explicit TableTypeVar(TableState state, TypeLevel level); - TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state = TableState::Unsealed); + TableTypeVar(const Props& props, const std::optional& indexer, TypeLevel level, TableState state); Props props; std::optional indexer; @@ -477,6 +477,9 @@ bool isOptional(TypeId ty); bool isTableIntersection(TypeId ty); bool isOverloadedFunction(TypeId ty); +// True when string is a subtype of ty +bool maybeString(TypeId ty); + std::optional getMetatable(TypeId type); TableTypeVar* getMutableTableType(TypeId type); const TableTypeVar* getTableType(TypeId type); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index fe822b012..4c0462fe5 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -56,7 +56,7 @@ struct Unifier Unifier(TypeArena* types, Mode mode, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); - Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, + Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog = nullptr); // Test whether the two type vars unify. Never commits the result. diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index 29a2c6b54..c3de8d0e1 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -16,7 +16,6 @@ LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); -LUAU_FASTFLAGVARIABLE(PreferToCallFunctionsForIntersects, false); LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); static const std::unordered_set kStatementStartingKeywords = { @@ -272,55 +271,34 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); - if (FFlag::PreferToCallFunctionsForIntersects) - { - auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return true; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) - { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return true; - } + auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { + auto [retHead, retTail] = flatten(ftv->retType); - return false; - }; + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) + return true; - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) { - return TypeCorrectKind::CorrectFunctionResult; - } - else if (const IntersectionTypeVar* itv = get(ty)) - { - for (TypeId id : itv->parts) - { - if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) - { - return TypeCorrectKind::CorrectFunctionResult; - } - } + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return true; } + + return false; + }; + + // We also want to suggest functions that return compatible result + if (const FunctionTypeVar* ftv = get(ty); ftv && checkFunctionType(ftv)) + { + return TypeCorrectKind::CorrectFunctionResult; } - else + else if (const IntersectionTypeVar* itv = get(ty)) { - // We also want to suggest functions that return compatible result - if (const FunctionTypeVar* ftv = get(ty)) + for (TypeId id : itv->parts) { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return TypeCorrectKind::CorrectFunctionResult; - - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + if (const FunctionTypeVar* ftv = get(id); ftv && checkFunctionType(ftv)) { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) - return TypeCorrectKind::CorrectFunctionResult; + return TypeCorrectKind::CorrectFunctionResult; } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index d72422a53..e4e5dab82 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -9,6 +9,7 @@ #include LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) +LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -283,8 +284,16 @@ void registerBuiltinTypes(TypeChecker& typeChecker) attachMagicFunction(getGlobalBinding(typeChecker, "setmetatable"), magicFunctionSetMetaTable); attachMagicFunction(getGlobalBinding(typeChecker, "select"), magicFunctionSelect); - auto tableLib = getMutable(getGlobalBinding(typeChecker, "table")); - attachMagicFunction(tableLib->props["pack"].type, magicFunctionPack); + if (TableTypeVar* ttv = getMutable(getGlobalBinding(typeChecker, "table"))) + { + // tabTy is a generic table type which we can't express via declaration syntax yet + ttv->props["freeze"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.freeze"); + + if (FFlag::LuauTableCloneType) + ttv->props["clone"] = makeProperty(makeFunction(arena, std::nullopt, {tabTy}, {tabTy}), "@luau/global/table.clone"); + + attachMagicFunction(ttv->props["pack"].type, magicFunctionPack); + } attachMagicFunction(getGlobalBinding(typeChecker, "require"), magicFunctionRequire); } diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index bf6e1193f..471b61ad8 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -170,7 +170,6 @@ declare function gcinfo(): number move: ({V}, number, number, number, {V}?) -> {V}, clear: ({[K]: V}) -> (), - freeze: ({[K]: V}) -> {[K]: V}, isfrozen: ({[K]: V}) -> boolean, } diff --git a/Analysis/src/JsonEncoder.cpp b/Analysis/src/JsonEncoder.cpp index ec3991581..811e7c24c 100644 --- a/Analysis/src/JsonEncoder.cpp +++ b/Analysis/src/JsonEncoder.cpp @@ -5,8 +5,6 @@ #include "Luau/StringUtils.h" #include "Luau/Common.h" -LUAU_FASTFLAG(LuauTypeAliasDefaults) - namespace Luau { @@ -369,38 +367,24 @@ struct AstJsonEncoder : public AstVisitor void write(const AstGenericType& genericType) { - if (FFlag::LuauTypeAliasDefaults) - { - writeRaw("{"); - bool c = pushComma(); - write("name", genericType.name); - if (genericType.defaultValue) - write("type", genericType.defaultValue); - popComma(c); - writeRaw("}"); - } - else - { - write(genericType.name); - } + writeRaw("{"); + bool c = pushComma(); + write("name", genericType.name); + if (genericType.defaultValue) + write("type", genericType.defaultValue); + popComma(c); + writeRaw("}"); } void write(const AstGenericTypePack& genericTypePack) { - if (FFlag::LuauTypeAliasDefaults) - { - writeRaw("{"); - bool c = pushComma(); - write("name", genericTypePack.name); - if (genericTypePack.defaultValue) - write("type", genericTypePack.defaultValue); - popComma(c); - writeRaw("}"); - } - else - { - write(genericTypePack.name); - } + writeRaw("{"); + bool c = pushComma(); + write("name", genericTypePack.name); + if (genericTypePack.defaultValue) + write("type", genericTypePack.defaultValue); + popComma(c); + writeRaw("}"); } void write(AstExprTable::Item::Kind kind) diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 7635dc0ff..56c4e3e89 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -13,6 +13,7 @@ #include LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) +LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) namespace Luau { @@ -233,6 +234,20 @@ class LintGlobalLocal : AstVisitor } private: + struct FunctionInfo + { + explicit FunctionInfo(AstExprFunction* ast) + : ast(ast) + , dominatedGlobals({}) + , conditionalExecution(false) + { + } + + AstExprFunction* ast; + DenseHashSet dominatedGlobals; + bool conditionalExecution; + }; + struct Global { AstExprGlobal* firstRef = nullptr; @@ -241,6 +256,9 @@ class LintGlobalLocal : AstVisitor bool assigned = false; bool builtin = false; + bool definedInModuleScope = false; + bool definedAsFunction = false; + bool readBeforeWritten = false; std::optional deprecated; }; @@ -248,7 +266,8 @@ class LintGlobalLocal : AstVisitor DenseHashMap globals; std::vector globalRefs; - std::vector functionStack; + std::vector functionStack; + LintGlobalLocal() : globals(AstName()) @@ -291,12 +310,18 @@ class LintGlobalLocal : AstVisitor "Global '%s' is only used in the enclosing function defined at line %d; consider changing it to local", g.firstRef->name.value, top->location.begin.line + 1); } + else if (FFlag::LuauLintGlobalNeverReadBeforeWritten && g.assigned && !g.readBeforeWritten && !g.definedInModuleScope && + g.firstRef->name != context->placeholder) + { + emitWarning(*context, LintWarning::Code_GlobalUsedAsLocal, g.firstRef->location, + "Global '%s' is never read before being written. Consider changing it to local", g.firstRef->name.value); + } } } bool visit(AstExprFunction* node) override { - functionStack.push_back(node); + functionStack.emplace_back(node); node->body->visit(this); @@ -307,6 +332,11 @@ class LintGlobalLocal : AstVisitor bool visit(AstExprGlobal* node) override { + if (FFlag::LuauLintGlobalNeverReadBeforeWritten && !functionStack.empty() && !functionStack.back().dominatedGlobals.contains(node->name)) + { + Global& g = globals[node->name]; + g.readBeforeWritten = true; + } trackGlobalRef(node); if (node->name == context->placeholder) @@ -335,6 +365,21 @@ class LintGlobalLocal : AstVisitor { Global& g = globals[gv->name]; + if (FFlag::LuauLintGlobalNeverReadBeforeWritten) + { + if (functionStack.empty()) + { + g.definedInModuleScope = true; + } + else + { + if (!functionStack.back().conditionalExecution) + { + functionStack.back().dominatedGlobals.insert(gv->name); + } + } + } + if (g.builtin) emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); @@ -369,7 +414,14 @@ class LintGlobalLocal : AstVisitor emitWarning(*context, LintWarning::Code_BuiltinGlobalWrite, gv->location, "Built-in global '%s' is overwritten here; consider using a local or changing the name", gv->name.value); else + { g.assigned = true; + if (FFlag::LuauLintGlobalNeverReadBeforeWritten) + { + g.definedAsFunction = true; + g.definedInModuleScope = functionStack.empty(); + } + } trackGlobalRef(gv); } @@ -377,6 +429,98 @@ class LintGlobalLocal : AstVisitor return true; } + class HoldConditionalExecution + { + public: + HoldConditionalExecution(LintGlobalLocal& p) + : p(p) + { + if (!p.functionStack.empty() && !p.functionStack.back().conditionalExecution) + { + resetToFalse = true; + p.functionStack.back().conditionalExecution = true; + } + } + ~HoldConditionalExecution() + { + if (resetToFalse) + p.functionStack.back().conditionalExecution = false; + } + + private: + bool resetToFalse = false; + LintGlobalLocal& p; + }; + + bool visit(AstStatIf* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->thenbody->visit(this); + if (node->elsebody) + node->elsebody->visit(this); + + return false; + } + + bool visit(AstStatWhile* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->body->visit(this); + + return false; + } + + bool visit(AstStatRepeat* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->condition->visit(this); + node->body->visit(this); + + return false; + } + + bool visit(AstStatFor* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + node->from->visit(this); + node->to->visit(this); + + if (node->step) + node->step->visit(this); + + node->body->visit(this); + + return false; + } + + bool visit(AstStatForIn* node) override + { + if (!FFlag::LuauLintGlobalNeverReadBeforeWritten) + return true; + + HoldConditionalExecution ce(*this); + for (AstExpr* expr : node->values) + expr->visit(this); + + node->body->visit(this); + + return false; + } + void trackGlobalRef(AstExprGlobal* node) { Global& g = globals[node->name]; @@ -390,7 +534,12 @@ class LintGlobalLocal : AstVisitor // to reduce the cost of tracking we only track this for user globals if (!g.builtin) { - g.functionRef = functionStack; + g.functionRef.clear(); + g.functionRef.reserve(functionStack.size()); + for (const FunctionInfo& entry : functionStack) + { + g.functionRef.push_back(entry.ast); + } } } else @@ -401,7 +550,7 @@ class LintGlobalLocal : AstVisitor // we need to find a common prefix between all uses of a global size_t prefix = 0; - while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix]) + while (prefix < g.functionRef.size() && prefix < functionStack.size() && g.functionRef[prefix] == functionStack[prefix].ast) prefix++; g.functionRef.resize(prefix); diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 412b78bbb..76dc72d22 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) -LUAU_FASTFLAG(LuauTypeAliasDefaults) LUAU_FASTFLAG(LuauImmutableTypes) namespace Luau @@ -463,7 +462,7 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, See TypeId ty = clone(param.ty, dest, seenTypes, seenTypePacks, cloneState); std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + if (param.defaultValue) defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); result.typeParams.push_back({ty, defaultValue}); @@ -474,7 +473,7 @@ TypeFun clone(const TypeFun& typeFun, TypeArena& dest, SeenTypes& seenTypes, See TypePackId tp = clone(param.tp, dest, seenTypes, seenTypePacks, cloneState); std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && param.defaultValue) + if (param.defaultValue) defaultValue = clone(*param.defaultValue, dest, seenTypes, seenTypePacks, cloneState); result.typePackParams.push_back({tp, defaultValue}); diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 5e79b8413..010ca3612 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -10,8 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauTypeAliasDefaults) - /* * Prefix generic typenames with gen- * Additionally, free types will be prefixed with free- and suffixed with their level. eg free-a-4 @@ -288,28 +286,15 @@ struct TypeVarStringifier else first = false; - if (FFlag::LuauTypeAliasDefaults) - { - bool wrap = !singleTp && get(follow(tp)); - - if (wrap) - state.emit("("); + bool wrap = !singleTp && get(follow(tp)); - stringify(tp); - - if (wrap) - state.emit(")"); - } - else - { - if (!singleTp) - state.emit("("); + if (wrap) + state.emit("("); - stringify(tp); + stringify(tp); - if (!singleTp) - state.emit(")"); - } + if (wrap) + state.emit(")"); } if (types.size() || typePacks.size()) @@ -1105,100 +1090,8 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction_DEPRECATED(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) -{ - std::string s = prefix; - - auto toString_ = [&opts](TypeId ty) -> std::string { - ToStringResult res = toStringDetailed(ty, opts); - opts.nameMap = std::move(res.nameMap); - return res.name; - }; - - auto toStringPack_ = [&opts](TypePackId ty) -> std::string { - ToStringResult res = toStringDetailed(ty, opts); - opts.nameMap = std::move(res.nameMap); - return res.name; - }; - - if (!opts.hideNamedFunctionTypeParameters && (!ftv.generics.empty() || !ftv.genericPacks.empty())) - { - s += "<"; - - bool first = true; - for (TypeId g : ftv.generics) - { - if (!first) - s += ", "; - first = false; - s += toString_(g); - } - - for (TypePackId gp : ftv.genericPacks) - { - if (!first) - s += ", "; - first = false; - s += toStringPack_(gp); - } - - s += ">"; - } - - s += "("; - - auto argPackIter = begin(ftv.argTypes); - auto argNameIter = ftv.argNames.begin(); - - bool first = true; - while (argPackIter != end(ftv.argTypes)) - { - if (!first) - s += ", "; - first = false; - - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) - { - s += (*argNameIter ? (*argNameIter)->name : "_") + ": "; - ++argNameIter; - } - else - { - s += "_: "; - } - - s += toString_(*argPackIter); - ++argPackIter; - } - - if (argPackIter.tail()) - { - if (auto vtp = get(*argPackIter.tail())) - s += ", ...: " + toString_(vtp->ty); - else - s += ", ...: " + toStringPack_(*argPackIter.tail()); - } - - s += "): "; - - size_t retSize = size(ftv.retType); - bool hasTail = !finite(ftv.retType); - if (retSize == 0 && !hasTail) - s += "()"; - else if ((retSize == 0 && hasTail) || (retSize == 1 && !hasTail)) - s += toStringPack_(ftv.retType); - else - s += "(" + toStringPack_(ftv.retType) + ")"; - - return s; -} - std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) { - if (!FFlag::LuauTypeAliasDefaults) - return toStringNamedFunction_DEPRECATED(prefix, ftv, opts); - ToStringResult result; StringifierState state(opts, result, opts.nameMap); TypeVarStringifier tvs{state}; diff --git a/Analysis/src/Transpiler.cpp b/Analysis/src/Transpiler.cpp index a02d396bd..92ed241ea 100644 --- a/Analysis/src/Transpiler.cpp +++ b/Analysis/src/Transpiler.cpp @@ -10,8 +10,6 @@ #include #include -LUAU_FASTFLAG(LuauTypeAliasDefaults) - namespace { bool isIdentifierStartChar(char c) @@ -796,21 +794,14 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - { - writer.advance(o.location.begin); - writer.identifier(o.name.value); - - if (o.defaultValue) - { - writer.maybeSpace(o.defaultValue->location.begin, 2); - writer.symbol("="); - visualizeTypeAnnotation(*o.defaultValue); - } - } - else + writer.advance(o.location.begin); + writer.identifier(o.name.value); + + if (o.defaultValue) { - writer.identifier(o.name.value); + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypeAnnotation(*o.defaultValue); } } @@ -818,23 +809,15 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - { - writer.advance(o.location.begin); - writer.identifier(o.name.value); - writer.symbol("..."); - - if (o.defaultValue) - { - writer.maybeSpace(o.defaultValue->location.begin, 2); - writer.symbol("="); - visualizeTypePackAnnotation(*o.defaultValue, false); - } - } - else + writer.advance(o.location.begin); + writer.identifier(o.name.value); + writer.symbol("..."); + + if (o.defaultValue) { - writer.identifier(o.name.value); - writer.symbol("..."); + writer.maybeSpace(o.defaultValue->location.begin, 2); + writer.symbol("="); + visualizeTypePackAnnotation(*o.defaultValue, false); } } @@ -882,18 +865,14 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); } for (const auto& o : func.genericPacks) { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); writer.symbol("..."); } @@ -1023,18 +1002,14 @@ struct Printer { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); } for (const auto& o : a->genericPacks) { comma(); - if (FFlag::LuauTypeAliasDefaults) - writer.advance(o.location.begin); - + writer.advance(o.location.begin); writer.identifier(o.name.value); writer.symbol("..."); } diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index 00067bdd1..c7bf1e62c 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -60,23 +60,53 @@ void DEPRECATED_TxnLog::concat(DEPRECATED_TxnLog rhs) } bool DEPRECATED_TxnLog::haveSeen(TypeId lhs, TypeId rhs) +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool DEPRECATED_TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void DEPRECATED_TxnLog::pushSeen(TypePackId lhs, TypePackId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void DEPRECATED_TxnLog::popSeen(TypePackId lhs, TypePackId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool DEPRECATED_TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); } -void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) +void DEPRECATED_TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } -void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) +void DEPRECATED_TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); } @@ -186,10 +216,40 @@ TxnLog TxnLog::inverse() } bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::pushSeen(TypeId lhs, TypeId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::popSeen(TypeId lhs, TypeId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) const +{ + return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::pushSeen(TypePackId lhs, TypePackId rhs) +{ + pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) +{ + popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); +} + +bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) { return true; @@ -203,19 +263,19 @@ bool TxnLog::haveSeen(TypeId lhs, TypeId rhs) const return false; } -void TxnLog::pushSeen(TypeId lhs, TypeId rhs) +void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } -void TxnLog::popSeen(TypeId lhs, TypeId rhs) +void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) { LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); + const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); } diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index faf60eb3e..8e6b3b52f 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,22 +29,20 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) -LUAU_FASTFLAGVARIABLE(LuauQuantifyInPlace2, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) -LUAU_FASTFLAGVARIABLE(LuauTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) +LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauAscribeCorrectLevelToInferredProperitesOfFreeTables, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAGVARIABLE(LuauAnotherTypeLevelFix, false) LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) +LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) namespace Luau { @@ -445,7 +443,7 @@ void TypeChecker::checkBlock(const ScopePtr& scope, const AstStatBlock& block) // function f(x:a):a local x: number = g(37) return x end // function g(x:number):number return f(x) end // ``` - if (FFlag::LuauQuantifyInPlace2 ? containsFunctionCallOrReturn(**protoIter) : containsFunctionCall(**protoIter)) + if (containsFunctionCallOrReturn(**protoIter)) { while (checkIter != protoIter) { @@ -1676,7 +1674,7 @@ std::optional TypeChecker::getIndexTypeFromType( } else if (tableType->state == TableState::Free) { - TypeId result = FFlag::LuauAscribeCorrectLevelToInferredProperitesOfFreeTables ? freshType(tableType->level) : freshType(scope); + TypeId result = freshType(tableType->level); tableType->props[name] = {result}; return result; } @@ -1776,31 +1774,62 @@ std::optional TypeChecker::getIndexTypeFromType( std::vector TypeChecker::reduceUnion(const std::vector& types) { - std::set s; - - for (TypeId t : types) + if (FFlag::LuauDoNotAccidentallyDependOnPointerOrdering) { - if (const UnionTypeVar* utv = get(follow(t))) + std::vector result; + for (TypeId t : types) { - std::vector r = reduceUnion(utv->options); - for (TypeId ty : r) - s.insert(ty); + t = follow(t); + if (get(t) || get(t)) + return {t}; + + if (const UnionTypeVar* utv = get(t)) + { + std::vector r = reduceUnion(utv->options); + for (TypeId ty : r) + { + ty = follow(ty); + if (get(ty) || get(ty)) + return {ty}; + + if (std::find(result.begin(), result.end(), ty) == result.end()) + result.push_back(ty); + } + } + else if (std::find(result.begin(), result.end(), t) == result.end()) + result.push_back(t); } - else - s.insert(t); - } - // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. - for (TypeId t : s) - { - t = follow(t); - if (get(t) || get(t)) - return {t}; + return result; } + else + { + std::set s; - std::vector r(s.begin(), s.end()); - std::sort(r.begin(), r.end()); - return r; + for (TypeId t : types) + { + if (const UnionTypeVar* utv = get(follow(t))) + { + std::vector r = reduceUnion(utv->options); + for (TypeId ty : r) + s.insert(ty); + } + else + s.insert(t); + } + + // If any of them are ErrorTypeVars/AnyTypeVars, decay into them. + for (TypeId t : s) + { + t = follow(t); + if (get(t) || get(t)) + return {t}; + } + + std::vector r(s.begin(), s.end()); + std::sort(r.begin(), r.end()); + return r; + } } std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) @@ -2811,7 +2840,7 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex } else if (exprTable->state == TableState::Unsealed || exprTable->state == TableState::Free) { - TypeId resultType = freshType(FFlag::LuauAnotherTypeLevelFix ? exprTable->level : scope->level); + TypeId resultType = freshType(exprTable->level); exprTable->indexer = TableIndexer{anyIfNonstrict(indexType), anyIfNonstrict(resultType)}; return resultType; } @@ -4453,51 +4482,6 @@ TypePackId ReplaceGenerics::clean(TypePackId tp) return addTypePack(TypePackVar(FreeTypePack{level})); } -bool Quantification::isDirty(TypeId ty) -{ - if (const TableTypeVar* ttv = log->getMutable(ty)) - return level.subsumes(ttv->level) && ((ttv->state == TableState::Free) || (ttv->state == TableState::Unsealed)); - else if (const FreeTypeVar* ftv = log->getMutable(ty)) - return level.subsumes(ftv->level); - else - return false; -} - -bool Quantification::isDirty(TypePackId tp) -{ - if (const FreeTypePack* ftv = log->getMutable(tp)) - return level.subsumes(ftv->level); - else - return false; -} - -TypeId Quantification::clean(TypeId ty) -{ - LUAU_ASSERT(isDirty(ty)); - if (const TableTypeVar* ttv = log->getMutable(ty)) - { - TableState state = (ttv->state == TableState::Unsealed ? TableState::Sealed : TableState::Generic); - TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, level, state}; - clone.methodDefinitionLocations = ttv->methodDefinitionLocations; - clone.definitionModuleName = ttv->definitionModuleName; - return addType(std::move(clone)); - } - else - { - TypeId generic = addType(GenericTypeVar{level}); - generics.push_back(generic); - return generic; - } -} - -TypePackId Quantification::clean(TypePackId tp) -{ - LUAU_ASSERT(isDirty(tp)); - TypePackId genericPack = addTypePack(TypePackVar(GenericTypePack{level})); - genericPacks.push_back(genericPack); - return genericPack; -} - bool Anyification::isDirty(TypeId ty) { if (const TableTypeVar* ttv = log->getMutable(ty)) @@ -4550,29 +4534,8 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location if (!ftv || !ftv->generics.empty() || !ftv->genericPacks.empty()) return ty; - if (FFlag::LuauQuantifyInPlace2) - { - Luau::quantify(ty, scope->level); - return ty; - } - - Quantification quantification{¤tModule->internalTypes, scope->level}; - std::optional qty = quantification.substitute(ty); - - if (!qty.has_value()) - { - reportError(location, UnificationTooComplex{}); - return errorRecoveryType(scope); - } - - if (ty == *qty) - return ty; - - FunctionTypeVar* qftv = getMutable(*qty); - LUAU_ASSERT(qftv); - qftv->generics = std::move(quantification.generics); - qftv->genericPacks = std::move(quantification.genericPacks); - return *qty; + Luau::quantify(ty, scope->level); + return ty; } TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) @@ -4915,35 +4878,20 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (lit->parameters.size == 0 && tf->typeParams.empty() && tf->typePackParams.empty()) return tf->type; - bool hasDefaultTypes = false; - bool hasDefaultPacks = false; bool parameterCountErrorReported = false; + bool hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); + bool hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { + return el.defaultValue.has_value(); + }); - if (FFlag::LuauTypeAliasDefaults) - { - hasDefaultTypes = std::any_of(tf->typeParams.begin(), tf->typeParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); - hasDefaultPacks = std::any_of(tf->typePackParams.begin(), tf->typePackParams.end(), [](auto&& el) { - return el.defaultValue.has_value(); - }); - - if (!lit->hasParameterList) - { - if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) - { - reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); - parameterCountErrorReported = true; - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); - } - } - } - else + if (!lit->hasParameterList) { - if (!lit->hasParameterList && !tf->typePackParams.empty()) + if ((!tf->typeParams.empty() && !hasDefaultTypes) || (!tf->typePackParams.empty() && !hasDefaultPacks)) { reportError(TypeError{annotation.location, GenericError{"Type parameter list is required"}}); + parameterCountErrorReported = true; if (!FFlag::LuauErrorRecoveryType) return errorRecoveryType(scope); } @@ -4986,72 +4934,69 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (typePackParams.empty() && !extraTypes.empty()) typePackParams.push_back(addTypePack(extraTypes)); - if (FFlag::LuauTypeAliasDefaults) - { - size_t typesProvided = typeParams.size(); - size_t typesRequired = tf->typeParams.size(); + size_t typesProvided = typeParams.size(); + size_t typesRequired = tf->typeParams.size(); - size_t packsProvided = typePackParams.size(); - size_t packsRequired = tf->typePackParams.size(); + size_t packsProvided = typePackParams.size(); + size_t packsRequired = tf->typePackParams.size(); - bool notEnoughParameters = - (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); - bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; + bool notEnoughParameters = + (typesProvided < typesRequired && packsProvided == 0) || (typesProvided == typesRequired && packsProvided < packsRequired); + bool hasDefaultParameters = hasDefaultTypes || hasDefaultPacks; - // Add default type and type pack parameters if that's required and it's possible - if (notEnoughParameters && hasDefaultParameters) - { - // 'applyTypeFunction' is used to substitute default types that reference previous generic types - ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; + // Add default type and type pack parameters if that's required and it's possible + if (notEnoughParameters && hasDefaultParameters) + { + // 'applyTypeFunction' is used to substitute default types that reference previous generic types + ApplyTypeFunction applyTypeFunction{¤tModule->internalTypes, scope->level}; - for (size_t i = 0; i < typesProvided; ++i) - applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; + for (size_t i = 0; i < typesProvided; ++i) + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = typeParams[i]; - if (typesProvided < typesRequired) + if (typesProvided < typesRequired) + { + for (size_t i = typesProvided; i < typesRequired; ++i) { - for (size_t i = typesProvided; i < typesRequired; ++i) - { - TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); - - if (!defaultTy) - break; + TypeId defaultTy = tf->typeParams[i].defaultValue.value_or(nullptr); - std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); + if (!defaultTy) + break; - if (!maybeInstantiated.has_value()) - { - reportError(annotation.location, UnificationTooComplex{}); - maybeInstantiated = errorRecoveryType(scope); - } + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTy); - applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; - typeParams.push_back(*maybeInstantiated); + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryType(scope); } + + applyTypeFunction.typeArguments[tf->typeParams[i].ty] = *maybeInstantiated; + typeParams.push_back(*maybeInstantiated); } + } - for (size_t i = 0; i < packsProvided; ++i) - applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; + for (size_t i = 0; i < packsProvided; ++i) + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = typePackParams[i]; - if (packsProvided < packsRequired) + if (packsProvided < packsRequired) + { + for (size_t i = packsProvided; i < packsRequired; ++i) { - for (size_t i = packsProvided; i < packsRequired; ++i) - { - TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); - - if (!defaultTp) - break; + TypePackId defaultTp = tf->typePackParams[i].defaultValue.value_or(nullptr); - std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); + if (!defaultTp) + break; - if (!maybeInstantiated.has_value()) - { - reportError(annotation.location, UnificationTooComplex{}); - maybeInstantiated = errorRecoveryTypePack(scope); - } + std::optional maybeInstantiated = applyTypeFunction.substitute(defaultTp); - applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; - typePackParams.push_back(*maybeInstantiated); + if (!maybeInstantiated.has_value()) + { + reportError(annotation.location, UnificationTooComplex{}); + maybeInstantiated = errorRecoveryTypePack(scope); } + + applyTypeFunction.typePackArguments[tf->typePackParams[i].tp] = *maybeInstantiated; + typePackParams.push_back(*maybeInstantiated); } } } @@ -5343,12 +5288,12 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, TypeId instantiated = *maybeInstantiated; - // TODO: CLI-46926 it's not a good idea to rename the type here TypeId target = follow(instantiated); bool needsClone = follow(tf.type) == target; + bool shouldMutate = (!FFlag::LuauOnlyMutateInstantiatedTables || getTableType(tf.type)); TableTypeVar* ttv = getMutableTableType(target); - - if (ttv && needsClone) + + if (shouldMutate && ttv && needsClone) { // Substitution::clone is a shallow clone. If this is a metatable type, we // want to mutate its table, so we need to explicitly clone that table as @@ -5368,7 +5313,7 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, } } - if (ttv) + if (shouldMutate && ttv) { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; @@ -5382,7 +5327,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st { LUAU_ASSERT(scope->parent); - const TypeLevel level = (FFlag::LuauQuantifyInPlace2 && levelOpt) ? *levelOpt : scope->level; + const TypeLevel level = levelOpt.value_or(scope->level); std::vector generics; @@ -5390,7 +5335,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st { std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && generic.defaultValue) + if (generic.defaultValue) defaultValue = resolveType(scope, *generic.defaultValue); Name n = generic.name.value; @@ -5426,7 +5371,7 @@ GenericTypeDefinitions TypeChecker::createGenericTypes(const ScopePtr& scope, st { std::optional defaultValue; - if (FFlag::LuauTypeAliasDefaults && genericPack.defaultValue) + if (genericPack.defaultValue) defaultValue = resolveTypePack(scope, *genericPack.defaultValue); Name n = genericPack.name.value; diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index a1dcfdbec..5af2c8a62 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -24,6 +24,7 @@ LUAU_FASTINTVARIABLE(LuauTypeMaximumStringifierLength, 500) LUAU_FASTINTVARIABLE(LuauTableTypeMaximumStringifierLength, 0) LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) +LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) namespace Luau @@ -157,6 +158,7 @@ bool isNumber(TypeId ty) return isPrim(ty, PrimitiveTypeVar::Number); } +// Returns true when ty is a subtype of string bool isString(TypeId ty) { if (isPrim(ty, PrimitiveTypeVar::String) || get(get(follow(ty)))) @@ -168,6 +170,27 @@ bool isString(TypeId ty) return false; } +// Returns true when ty is a supertype of string +bool maybeString(TypeId ty) +{ + if (FFlag::LuauSubtypingAddOptPropsToUnsealedTables) + { + ty = follow(ty); + + if (isPrim(ty, PrimitiveTypeVar::String) || get(ty)) + return true; + + if (auto utv = get(ty)) + return std::any_of(begin(utv), end(utv), maybeString); + + return false; + } + else + { + return isString(ty); + } +} + bool isThread(TypeId ty) { return isPrim(ty, PrimitiveTypeVar::Thread); @@ -684,7 +707,7 @@ TypeId SingletonTypes::makeStringMetatable() {"sub", {makeFunction(*arena, stringType, {}, {}, {numberType, optionalNumber}, {}, {stringType})}}, {"upper", {stringToStringType}}, {"split", {makeFunction(*arena, stringType, {}, {}, {optionalString}, {}, - {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}})})}}, + {arena->addType(TableTypeVar{{}, TableIndexer{numberType, stringType}, TypeLevel{}, TableState::Sealed})})}}, {"pack", {arena->addType(FunctionTypeVar{ arena->addTypePack(TypePack{{stringType}, anyTypePack}), oneStringPack, @@ -761,6 +784,8 @@ void persist(TypeId ty) } else if (auto ttv = get(t)) { + LUAU_ASSERT(ttv->state != TableState::Free && ttv->state != TableState::Unsealed); + for (const auto& [_name, prop] : ttv->props) queue.push_back(prop.type); diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index d0eba0135..6c29486a4 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -18,11 +18,13 @@ LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); -LUAU_FASTFLAGVARIABLE(LuauTableUnificationEarlyTest, false) LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); +LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) +LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, true) namespace Luau { @@ -329,7 +331,7 @@ Unifier::Unifier(TypeArena* types, Mode mode, const Location& location, Variance LUAU_ASSERT(sharedState.iceHandler); } -Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, +Unifier::Unifier(TypeArena* types, Mode mode, std::vector>* sharedSeen, const Location& location, Variance variance, UnifierSharedState& sharedState, TxnLog* parentLog) : types(types) , mode(mode) @@ -656,26 +658,85 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId failed = true; } - if (FFlag::LuauUseCommittingTxnLog) + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } + if (!FFlag::LuauUseCommittingTxnLog) + innerState.DEPRECATED_log.rollback(); } else { - if (i != count - 1) + if (FFlag::LuauUseCommittingTxnLog) { - innerState.DEPRECATED_log.rollback(); + if (i == count - 1) + { + log.concat(std::move(innerState.log)); + } } else { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + if (i != count - 1) + { + innerState.DEPRECATED_log.rollback(); + } + else + { + DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + } } + + ++i; } + } + + // even if A | B <: T fails, we want to bind some options of T with A | B iff A | B was a subtype of that option. + if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) + { + auto tryBind = [this, subTy](TypeId superOption) { + superOption = FFlag::LuauUseCommittingTxnLog ? log.follow(superOption) : follow(superOption); + + // just skip if the superOption is not free-ish. + auto ttv = log.getMutable(superOption); + if (!log.is(superOption) && (!ttv || ttv->state != TableState::Free)) + return; - ++i; + // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. + // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. + if (FFlag::LuauUseCommittingTxnLog) + { + if (log.haveSeen(subTy, superOption)) + { + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); + } + } + else + { + if (DEPRECATED_log.haveSeen(subTy, superOption)) + { + if (auto ttv = getMutable(superOption)) + { + DEPRECATED_log(ttv); + ttv->boundTo = subTy; + } + else + { + DEPRECATED_log(superOption); + *asMutable(superOption) = BoundTypeVar(subTy); + } + } + } + }; + + if (auto utv = (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) + { + for (TypeId ty : utv) + tryBind(ty); + } + else + tryBind(superTy); } if (unificationTooComplex) @@ -1163,6 +1224,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superTp == subTp) return; + if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) + return; + if (log.getMutable(superTp)) { occursCheck(superTp, subTp); @@ -1365,6 +1429,9 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (superTp == subTp) return; + if (FFlag::LuauTxnLogSeesTypePacks2 && DEPRECATED_log.haveSeen(superTp, subTp)) + return; + if (get(superTp)) { occursCheck(superTp, subTp); @@ -1619,6 +1686,17 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal DEPRECATED_log.pushSeen(superFunction->generics[i], subFunction->generics[i]); } + if (FFlag::LuauTxnLogSeesTypePacks2) + { + for (size_t i = 0; i < numGenericPacks; i++) + { + if (FFlag::LuauUseCommittingTxnLog) + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + else + DEPRECATED_log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + } + } + CountMismatch::Context context = ctx; if (!isFunctionCall) @@ -1708,6 +1786,17 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal ctx = context; + if (FFlag::LuauTxnLogSeesTypePacks2) + { + for (int i = int(numGenericPacks) - 1; 0 <= i; i--) + { + if (FFlag::LuauUseCommittingTxnLog) + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + else + DEPRECATED_log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + } + } + for (int i = int(numGenerics) - 1; 0 <= i; i--) { if (FFlag::LuauUseCommittingTxnLog) @@ -1760,7 +1849,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) std::vector extraProperties; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer && subTable->state != TableState::Free) + if (!subTable->indexer && subTable->state != TableState::Free) { for (const auto& [propName, superProp] : superTable->props) { @@ -1769,7 +1858,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) bool isAny = FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(superProp.type)) : get(follow(superProp.type)); - if (subIter == subTable->props.end() && !isOptional(superProp.type) && !isAny) + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) missingProperties.push_back(propName); } @@ -1781,7 +1870,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } // And vice versa if we're invariant - if (FFlag::LuauTableUnificationEarlyTest && variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && + if (variance == Invariant && !superTable->indexer && superTable->state != TableState::Unsealed && superTable->state != TableState::Free) { for (const auto& [propName, subProp] : subTable->props) @@ -1790,7 +1879,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) bool isAny = FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(subProp.type)) : get(follow(subProp.type)); - if (superIter == superTable->props.end() && !isOptional(subProp.type) && !isAny) + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) extraProperties.push_back(propName); } @@ -1830,7 +1919,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) innerState.DEPRECATED_log.rollback(); } } - else if (subTable->indexer && isString(subTable->indexer->indexType)) + else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1855,9 +1944,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) innerState.DEPRECATED_log.rollback(); } } - else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 + else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) + // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` + // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. // TODO: should isOptional(anyType) be true? + // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) { } else if (subTable->state == TableState::Free) @@ -1887,7 +1978,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // If both lt and rt contain the property, then // we're done since we already unified them above } - else if (superTable->indexer && isString(superTable->indexer->indexType)) + else if (superTable->indexer && maybeString(superTable->indexer->indexType)) { // TODO: read-only indexers don't need invariance // TODO: really we should only allow this if prop.type is optional. @@ -1936,9 +2027,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } - else if (isOptional(prop.type) || get(follow(prop.type))) - // TODO: this case is unsound, but without it our test suite fails. CLI-46031 - // TODO: should isOptional(anyType) be true? + else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && (isOptional(prop.type) || get(follow(prop.type)))) { } else if (superTable->state == TableState::Free) @@ -2333,7 +2422,7 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec bool errorReported = false; // Optimization: First test that the property sets are compatible without doing any recursive unification - if (FFlag::LuauTableUnificationEarlyTest && !subTable->indexer) + if (!subTable->indexer) { for (const auto& [propName, superProp] : superTable->props) { diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 8767daa05..1cb8f1343 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,7 +11,6 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauParseTypeAliasDefaults, false) LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) @@ -779,7 +778,7 @@ AstStat* Parser::parseTypeAlias(const Location& start, bool exported) if (!name) name = Name(nameError, lexer.current().location); - auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ FFlag::LuauParseTypeAliasDefaults); + auto [generics, genericPacks] = parseGenericTypeList(/* withDefaultValues= */ true); expectAndConsume('=', "type alias"); diff --git a/CLI/Repl.cpp b/CLI/Repl.cpp index 13304d57f..5fd6d3413 100644 --- a/CLI/Repl.cpp +++ b/CLI/Repl.cpp @@ -413,29 +413,22 @@ static void completeRepl(ic_completion_env_t* cenv, const char* editBuffer) ic_complete_word(cenv, editBuffer, icGetCompletions, isMethodOrFunctionChar); } -struct LinenoiseScopedHistory +static void loadHistory(const char* name) { - LinenoiseScopedHistory() - { - const std::string name(".luau_history"); + std::string path; - if (const char* home = getenv("HOME")) - { - historyFilepath = joinPaths(home, name); - } - else if (const char* userProfile = getenv("USERPROFILE")) - { - historyFilepath = joinPaths(userProfile, name); - } - - if (!historyFilepath.empty()) - ic_set_history(historyFilepath.c_str(), -1 /* default entries (= 200) */); + if (const char* home = getenv("HOME")) + { + path = joinPaths(home, name); + } + else if (const char* userProfile = getenv("USERPROFILE")) + { + path = joinPaths(userProfile, name); } - ~LinenoiseScopedHistory() {} - - std::string historyFilepath; -}; + if (!path.empty()) + ic_set_history(path.c_str(), -1 /* default entries (= 200) */); +} static void runReplImpl(lua_State* L) { @@ -447,8 +440,10 @@ static void runReplImpl(lua_State* L) // Prevent auto insertion of braces ic_enable_brace_insertion(false); + // Loads history from the given file; isocline automatically saves the history on process exit + loadHistory(".luau_history"); + std::string buffer; - LinenoiseScopedHistory scopedHistory; for (;;) { diff --git a/VM/include/lua.h b/VM/include/lua.h index af0e28354..0a561f274 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -265,6 +265,8 @@ LUA_API double lua_clock(); LUA_API void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)); +LUA_API void lua_clonefunction(lua_State* L, int idx); + /* ** reference system, can be used to pin objects */ @@ -324,6 +326,7 @@ typedef struct lua_Debug lua_Debug; /* activation record */ /* Functions to be called by the debugger in specific events */ typedef void (*lua_Hook)(lua_State* L, lua_Debug* ar); +LUA_API int lua_stackdepth(lua_State* L); LUA_API int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar); LUA_API int lua_getargument(lua_State* L, int level, int n); LUA_API const char* lua_getlocal(lua_State* L, int level, int n); diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 39c76e087..f7f154428 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,7 +14,7 @@ #include -LUAU_FASTFLAGVARIABLE(LuauGcForwardMetatableBarrier, false) +LUAU_FASTFLAG(LuauGcAdditionalStats) const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" @@ -876,16 +876,7 @@ int lua_setmetatable(lua_State* L, int objindex) luaG_runerror(L, "Attempt to modify a readonly table"); hvalue(obj)->metatable = mt; if (mt) - { - if (FFlag::LuauGcForwardMetatableBarrier) - { - luaC_objbarrier(L, hvalue(obj), mt); - } - else - { - luaC_objbarriert(L, hvalue(obj), mt); - } - } + luaC_objbarrier(L, hvalue(obj), mt); break; } case LUA_TUSERDATA: @@ -1069,6 +1060,8 @@ int lua_gc(lua_State* L, int what, int data) g->GCthreshold = 0; bool waspaused = g->gcstate == GCSpause; + double startmarktime = g->gcstats.currcycle.marktime; + double startsweeptime = g->gcstats.currcycle.sweeptime; // track how much work the loop will actually perform size_t actualwork = 0; @@ -1086,6 +1079,31 @@ int lua_gc(lua_State* L, int what, int data) } } + if (FFlag::LuauGcAdditionalStats) + { + // record explicit step statistics + GCCycleStats* cyclestats = g->gcstate == GCSpause ? &g->gcstats.lastcycle : &g->gcstats.currcycle; + + double totalmarktime = cyclestats->marktime - startmarktime; + double totalsweeptime = cyclestats->sweeptime - startsweeptime; + + if (totalmarktime > 0.0) + { + cyclestats->markexplicitsteps++; + + if (totalmarktime > cyclestats->markmaxexplicittime) + cyclestats->markmaxexplicittime = totalmarktime; + } + + if (totalsweeptime > 0.0) + { + cyclestats->sweepexplicitsteps++; + + if (totalsweeptime > cyclestats->sweepmaxexplicittime) + cyclestats->sweepmaxexplicittime = totalsweeptime; + } + } + // if cycle hasn't finished, advance threshold forward for the amount of extra work performed if (g->gcstate != GCSpause) { @@ -1299,6 +1317,18 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)) L->global->udatagc[tag] = dtor; } +LUA_API void lua_clonefunction(lua_State* L, int idx) +{ + StkId p = index2addr(L, idx); + api_check(L, isLfunction(p)); + + luaC_checkthreadsleep(L); + + Closure* cl = clvalue(p); + Closure* newcl = luaF_newLclosure(L, 0, L->gt, cl->l.p); + setclvalue(L, L->top - 1, newcl); +} + lua_Callbacks* lua_callbacks(lua_State* L) { return &L->global->cb; diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 71975a520..9a6f77938 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,8 +11,6 @@ #include -LUAU_FASTFLAG(LuauSchubfach) - /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -480,18 +478,13 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) switch (lua_type(L, idx)) { case LUA_TNUMBER: - if (FFlag::LuauSchubfach) - { - double n = lua_tonumber(L, idx); - char s[LUAI_MAXNUM2STR]; - char* e = luai_num2str(s, n); - lua_pushlstring(L, s, e - s); - } - else - { - lua_pushstring(L, lua_tostring(L, idx)); - } + { + double n = lua_tonumber(L, idx); + char s[LUAI_MAXNUM2STR]; + char* e = luai_num2str(s, n); + lua_pushlstring(L, s, e - s); break; + } case LUA_TSTRING: lua_pushvalue(L, idx); break; @@ -505,29 +498,18 @@ const char* luaL_tolstring(lua_State* L, int idx, size_t* len) { const float* v = lua_tovector(L, idx); - if (FFlag::LuauSchubfach) + char s[LUAI_MAXNUM2STR * LUA_VECTOR_SIZE]; + char* e = s; + for (int i = 0; i < LUA_VECTOR_SIZE; ++i) { - char s[LUAI_MAXNUM2STR * LUA_VECTOR_SIZE]; - char* e = s; - for (int i = 0; i < LUA_VECTOR_SIZE; ++i) + if (i != 0) { - if (i != 0) - { - *e++ = ','; - *e++ = ' '; - } - e = luai_num2str(e, v[i]); + *e++ = ','; + *e++ = ' '; } - lua_pushlstring(L, s, e - s); - } - else - { -#if LUA_VECTOR_SIZE == 4 - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2], v[3]); -#else - lua_pushfstring(L, LUA_NUMBER_FMT ", " LUA_NUMBER_FMT ", " LUA_NUMBER_FMT, v[0], v[1], v[2]); -#endif + e = luai_num2str(e, v[i]); } + lua_pushlstring(L, s, e - s); break; } default: diff --git a/VM/src/ldebug.cpp b/VM/src/ldebug.cpp index a4f93c621..7a9947b74 100644 --- a/VM/src/ldebug.cpp +++ b/VM/src/ldebug.cpp @@ -168,6 +168,11 @@ static int auxgetinfo(lua_State* L, const char* what, lua_Debug* ar, Closure* f, return status; } +int lua_stackdepth(lua_State* L) +{ + return int(L->ci - L->base_ci); +} + int lua_getinfo(lua_State* L, int level, const char* what, lua_Debug* ar) { int status = 0; diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index 8c3a20296..a656854ed 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -11,6 +11,8 @@ #include "lmem.h" #include "ludata.h" +LUAU_FASTFLAGVARIABLE(LuauGcAdditionalStats, false) + #include #define GC_SWEEPMAX 40 @@ -53,17 +55,28 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, case GCSpause: // record root mark time if we have switched to next state if (g->gcstate == GCSpropagate) + { g->gcstats.currcycle.marktime += seconds; + + if (FFlag::LuauGcAdditionalStats && assist) + g->gcstats.currcycle.markassisttime += seconds; + } break; case GCSpropagate: case GCSpropagateagain: g->gcstats.currcycle.marktime += seconds; + + if (FFlag::LuauGcAdditionalStats && assist) + g->gcstats.currcycle.markassisttime += seconds; break; case GCSatomic: g->gcstats.currcycle.atomictime += seconds; break; case GCSsweep: g->gcstats.currcycle.sweeptime += seconds; + + if (FFlag::LuauGcAdditionalStats && assist) + g->gcstats.currcycle.sweepassisttime += seconds; break; default: LUAU_ASSERT(!"Unexpected GC state"); @@ -78,7 +91,7 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, static void startGcCycleStats(global_State* g) { g->gcstats.currcycle.starttimestamp = lua_clock(); - g->gcstats.currcycle.waittime = g->gcstats.currcycle.starttimestamp - g->gcstats.lastcycle.endtimestamp; + g->gcstats.currcycle.pausetime = g->gcstats.currcycle.starttimestamp - g->gcstats.lastcycle.endtimestamp; } static void finishGcCycleStats(global_State* g) @@ -585,10 +598,21 @@ static size_t atomic(lua_State* L) LUAU_ASSERT(g->gcstate == GCSatomic); size_t work = 0; + double currts = lua_clock(); + double prevts = currts; + /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ work += propagateall(g); + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimeupval += currts - prevts; + prevts = currts; + } + /* remark weak tables */ g->gray = g->weak; g->weak = NULL; @@ -596,16 +620,41 @@ static size_t atomic(lua_State* L) markobject(g, L); /* mark running thread */ markmt(g); /* mark basic metatables (again) */ work += propagateall(g); + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimeweak += currts - prevts; + prevts = currts; + } + /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; work += propagateall(g); - work += cleartable(L, g->weak); /* remove collected objects from weak tables */ + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimegray += currts - prevts; + prevts = currts; + } + + /* remove collected objects from weak tables */ + work += cleartable(L, g->weak); g->weak = NULL; + + if (FFlag::LuauGcAdditionalStats) + { + currts = lua_clock(); + g->gcstats.currcycle.atomictimeclear += currts - prevts; + } + /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); g->sweepgcopage = g->allgcopages; g->gcstate = GCSsweep; + return work; } @@ -693,6 +742,9 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) { + if (FFlag::LuauGcAdditionalStats) + g->gcstats.currcycle.propagatework = g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork; + // perform one iteration over 'gray again' list g->gray = g->grayagain; g->grayagain = NULL; @@ -710,6 +762,10 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { + if (FFlag::LuauGcAdditionalStats) + g->gcstats.currcycle.propagateagainwork = + g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork - g->gcstats.currcycle.propagatework; + g->gcstate = GCSatomic; } break; @@ -811,6 +867,12 @@ static size_t getheaptrigger(global_State* g, size_t heapgoal) void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; + + if (assist) + g->gcstats.currcycle.assistrequests += g->gcstepsize; + else + g->gcstats.currcycle.explicitrequests += g->gcstepsize; + int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -833,6 +895,11 @@ void luaC_step(lua_State* L, bool assist) recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); + if (lastgcstate == GCSpropagate) + g->gcstats.currcycle.markrequests += g->gcstepsize; + else if (lastgcstate == GCSsweep) + g->gcstats.currcycle.sweeprequests += g->gcstepsize; + // at the end of the last cycle if (g->gcstate == GCSpause) { @@ -844,6 +911,9 @@ void luaC_step(lua_State* L, bool assist) finishGcCycleStats(g); + if (FFlag::LuauGcAdditionalStats) + g->gcstats.currcycle.starttotalsizebytes = g->totalbytes; + g->gcstats.currcycle.heapgoalsizebytes = heapgoal; g->gcstats.currcycle.heaptriggersizebytes = heaptrigger; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index 253e269f9..cbeeebd48 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -111,13 +111,6 @@ luaC_barrierf(L, obj2gco(p), obj2gco(o)); \ } -// TODO: remove with FFlagLuauGcForwardMetatableBarrier -#define luaC_objbarriert(L, t, o) \ - { \ - if (isblack(obj2gco(t)) && iswhite(obj2gco(o))) \ - luaC_barriertable(L, t, obj2gco(o)); \ - } - #define luaC_upvalbarrier(L, uv, tv) \ { \ if (iscollectable(tv) && iswhite(gcvalue(tv)) && (!(uv) || ((UpVal*)uv)->v != &((UpVal*)uv)->u.value)) \ diff --git a/VM/src/lnumprint.cpp b/VM/src/lnumprint.cpp index 2fd0f1bbd..d64e3ca40 100644 --- a/VM/src/lnumprint.cpp +++ b/VM/src/lnumprint.cpp @@ -6,7 +6,6 @@ #include "lcommon.h" #include -#include // TODO: Remove with LuauSchubfach #ifdef _MSC_VER #include @@ -18,8 +17,6 @@ // The code uses the notation from the paper for local variables where appropriate, and refers to paper sections/figures/results. -LUAU_FASTFLAGVARIABLE(LuauSchubfach, false) - // 9.8.2. Precomputed table for 128-bit overestimates of powers of 10 (see figure 3 for table bounds) // To avoid storing 616 128-bit numbers directly we use a technique inspired by Dragonbox implementation and store 16 consecutive // powers using a 128-bit baseline and a bitvector with 1-bit scale and 3-bit offset for the delta between each entry and base*5^k @@ -275,12 +272,6 @@ inline char* trimzero(char* end) char* luai_num2str(char* buf, double n) { - if (!FFlag::LuauSchubfach) - { - snprintf(buf, LUAI_MAXNUM2STR, LUA_NUMBER_FMT, n); - return buf + strlen(buf); - } - // IEEE-754 union { diff --git a/VM/src/lnumutils.h b/VM/src/lnumutils.h index fba07bc3b..549b4630d 100644 --- a/VM/src/lnumutils.h +++ b/VM/src/lnumutils.h @@ -55,7 +55,6 @@ LUAU_FASTMATH_END #define luai_num2unsigned(i, n) ((i) = (unsigned)(long long)(n)) #endif -#define LUA_NUMBER_FMT "%.14g" /* TODO: Remove with LuauSchubfach */ #define LUAI_MAXNUM2STR 48 LUAI_FUNC char* luai_num2str(char* buf, double n); diff --git a/VM/src/lstate.h b/VM/src/lstate.h index 3ee967181..b2bedb486 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -77,25 +77,46 @@ typedef struct CallInfo struct GCCycleStats { + size_t starttotalsizebytes = 0; size_t heapgoalsizebytes = 0; size_t heaptriggersizebytes = 0; - double waittime = 0.0; // time from end of the last cycle to the start of a new one + double pausetime = 0.0; // time from end of the last cycle to the start of a new one double starttimestamp = 0.0; double endtimestamp = 0.0; double marktime = 0.0; + double markassisttime = 0.0; + double markmaxexplicittime = 0.0; + size_t markexplicitsteps = 0; + size_t markrequests = 0; double atomicstarttimestamp = 0.0; size_t atomicstarttotalsizebytes = 0; double atomictime = 0.0; + // specific atomic stage parts + double atomictimeupval = 0.0; + double atomictimeweak = 0.0; + double atomictimegray = 0.0; + double atomictimeclear = 0.0; + double sweeptime = 0.0; + double sweepassisttime = 0.0; + double sweepmaxexplicittime = 0.0; + size_t sweepexplicitsteps = 0; + size_t sweeprequests = 0; + + size_t assistrequests = 0; + size_t explicitrequests = 0; size_t assistwork = 0; size_t explicitwork = 0; + size_t propagatework = 0; + size_t propagateagainwork = 0; + size_t endtotalsizebytes = 0; }; diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 0412ea76d..ef0b4b93a 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -447,7 +447,8 @@ void luaH_free(lua_State* L, Table* t, lua_Page* page) { if (t->node != dummynode) luaM_freearray(L, t->node, sizenode(t), LuaNode, t->memcat); - luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); + if (t->array) + luaM_freearray(L, t->array, t->sizearray, TValue, t->memcat); luaM_freegco(L, t, sizeof(Table), t->memcat, page); } diff --git a/VM/src/ltablib.cpp b/VM/src/ltablib.cpp index 0d3374efa..007537426 100644 --- a/VM/src/ltablib.cpp +++ b/VM/src/ltablib.cpp @@ -2,6 +2,7 @@ // This code is based on Lua 5.x implementation licensed under MIT License; see lua_LICENSE.txt for details #include "lualib.h" +#include "lapi.h" #include "lstate.h" #include "ltable.h" #include "lstring.h" @@ -9,6 +10,8 @@ #include "ldebug.h" #include "lvm.h" +LUAU_FASTFLAGVARIABLE(LuauTableClone, false) + static int foreachi(lua_State* L) { luaL_checktype(L, 1, LUA_TTABLE); @@ -507,6 +510,23 @@ static int tisfrozen(lua_State* L) return 1; } +static int tclone(lua_State* L) +{ + if (!FFlag::LuauTableClone) + luaG_runerror(L, "table.clone is not available"); + + luaL_checktype(L, 1, LUA_TTABLE); + luaL_argcheck(L, !luaL_getmetafield(L, 1, "__metatable"), 1, "table has a protected metatable"); + + Table* tt = luaH_clone(L, hvalue(L->base)); + + TValue v; + sethvalue(L, &v, tt); + luaA_pushobject(L, &v); + + return 1; +} + static const luaL_Reg tab_funcs[] = { {"concat", tconcat}, {"foreach", foreach}, @@ -524,6 +544,7 @@ static const luaL_Reg tab_funcs[] = { {"clear", tclear}, {"freeze", tfreeze}, {"isfrozen", tisfrozen}, + {"clone", tclone}, {NULL, NULL}, }; diff --git a/bench/gc/test_GC_Boehm_Trees.lua b/bench/gc/test_GC_Boehm_Trees.lua index 1451769f2..5abad1d80 100644 --- a/bench/gc/test_GC_Boehm_Trees.lua +++ b/bench/gc/test_GC_Boehm_Trees.lua @@ -74,4 +74,7 @@ function test() end end +bench.runs = 6 +bench.extraRuns = 2 + bench.runCode(test, "GC: Boehm tree") diff --git a/bench/gc/test_GC_Tree_Pruning_Eager.lua b/bench/gc/test_GC_Tree_Pruning_Eager.lua index 514766ae3..2111d9ffa 100644 --- a/bench/gc/test_GC_Tree_Pruning_Eager.lua +++ b/bench/gc/test_GC_Tree_Pruning_Eager.lua @@ -40,7 +40,7 @@ function test() local tree = { id = 0 } - for i = 1,1000 do + for i = 1,100 do fill_tree(tree, 10) prune_tree(tree, 0) diff --git a/bench/gc/test_GC_Tree_Pruning_Gen.lua b/bench/gc/test_GC_Tree_Pruning_Gen.lua index a8d0f40a2..f88bd7f47 100644 --- a/bench/gc/test_GC_Tree_Pruning_Gen.lua +++ b/bench/gc/test_GC_Tree_Pruning_Gen.lua @@ -42,7 +42,7 @@ function test() local tree = { id = 0 } fill_tree(tree, 16) - for i = 1,1000 do + for i = 1,100 do local small_tree = { id = 0 } fill_tree(small_tree, 8) diff --git a/bench/gc/test_GC_Tree_Pruning_Lazy.lua b/bench/gc/test_GC_Tree_Pruning_Lazy.lua index 8cb691921..3ea6bbef1 100644 --- a/bench/gc/test_GC_Tree_Pruning_Lazy.lua +++ b/bench/gc/test_GC_Tree_Pruning_Lazy.lua @@ -46,7 +46,7 @@ function test() local tree = { id = 0 } - for i = 1,1000 do + for i = 1,100 do fill_tree(tree, 10) prune_tree(tree, 0) diff --git a/extern/isocline/include/isocline.h b/extern/isocline/include/isocline.h index 0d46cf3ff..a7e03ed2b 100644 --- a/extern/isocline/include/isocline.h +++ b/extern/isocline/include/isocline.h @@ -259,7 +259,7 @@ void ic_complete_qword( ic_completion_env_t* cenv, const char* prefix, ic_comple /// The `escape_char` is the escaping character, usually `\` but use 0 to not have escape characters. /// The `quote_chars` define the quotes, use NULL for the default `"\'\""` quotes. /// @see ic_complete_word() which uses the default values for `non_word_chars`, `quote_chars` and `\` for escape characters. -void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t fun, +void ic_complete_qword_ex( ic_completion_env_t* cenv, const char* prefix, ic_completer_fun_t* fun, ic_is_char_class_fun_t* is_word_char, char escape_char, const char* quote_chars ); /// \} diff --git a/fuzz/number.cpp b/fuzz/number.cpp index 704474096..31c953e31 100644 --- a/fuzz/number.cpp +++ b/fuzz/number.cpp @@ -6,8 +6,6 @@ #include #include -LUAU_FASTFLAG(LuauSchubfach); - #define LUAI_MAXNUM2STR 48 char* luai_num2str(char* buf, double n); @@ -17,8 +15,6 @@ extern "C" int LLVMFuzzerTestOneInput(const uint8_t* Data, size_t Size) if (Size < 8) return 0; - FFlag::LuauSchubfach.value = true; - double num; memcpy(&num, Data, 8); diff --git a/fuzz/proto.cpp b/fuzz/proto.cpp index 912fef231..1022831b4 100644 --- a/fuzz/proto.cpp +++ b/fuzz/proto.cpp @@ -59,7 +59,7 @@ void interrupt(lua_State* L, int gc) } } -void* allocate(lua_State* L, void* ud, void* ptr, size_t osize, size_t nsize) +void* allocate(void* ud, void* ptr, size_t osize, size_t nsize) { if (nsize == 0) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 6aadef329..ce890ba89 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -15,6 +15,7 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) LUAU_FASTFLAG(LuauUseCommittingTxnLog) +LUAU_FASTFLAG(LuauTableCloneType) using namespace Luau; @@ -262,7 +263,7 @@ TEST_CASE_FIXTURE(ACFixture, "get_member_completions") auto ac = autocomplete('1'); - CHECK_EQ(16, ac.entryMap.size()); + CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -2235,7 +2236,7 @@ TEST_CASE_FIXTURE(ACFixture, "autocompleteSource") auto ac = autocompleteSource(frontend, source, Position{1, 24}, nullCallback).result; - CHECK_EQ(16, ac.entryMap.size()); + CHECK_EQ(FFlag::LuauTableCloneType ? 17 : 16, ac.entryMap.size()); CHECK(ac.entryMap.count("find")); CHECK(ac.entryMap.count("pack")); CHECK(!ac.entryMap.count("math")); @@ -2695,8 +2696,6 @@ local r4 = t:bar1(@4) TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_parameters") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - check(R"( type A = () -> T )"); @@ -2709,8 +2708,6 @@ type A = () -> T TEST_CASE_FIXTURE(ACFixture, "autocomplete_default_type_pack_parameters") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - check(R"( type A = () -> T )"); @@ -2768,7 +2765,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") { ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - ScopedFastFlag preferToCallFunctionsForIntersects("PreferToCallFunctionsForIntersects", true); check(R"( local bar: ((number) -> number) & (number, number) -> number) diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index eb6ca749a..63fbb363b 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -241,6 +241,8 @@ TEST_CASE("Math") TEST_CASE("Table") { + ScopedFastFlag sff("LuauTableClone", true); + runConformance("nextvar.lua"); } @@ -465,6 +467,8 @@ static void populateRTTI(lua_State* L, Luau::TypeId type) TEST_CASE("Types") { + ScopedFastFlag sff("LuauTableCloneType", true); + runConformance("types.lua", [](lua_State* L) { Luau::NullModuleResolver moduleResolver; Luau::InternalErrorReporter iceHandler; @@ -959,8 +963,6 @@ TEST_CASE("Coverage") TEST_CASE("StringConversion") { - ScopedFastFlag sff{"LuauSchubfach", true}; - runConformance("strconv.lua"); } diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index ab19cea30..4d6c207cc 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -157,6 +157,113 @@ return bar() CHECK_EQ(result.warnings[0].text, "Global 'foo' is only used in the enclosing function 'bar'; consider changing it to local"); } +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFx") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +return bar() + baz() +)"); + + REQUIRE_EQ(result.warnings.size(), 1); + CHECK_EQ(result.warnings[0].text, "Global 'foo' is never read before being written. Consider changing it to local"); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMultiFxWithRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +function read() + print(foo) +end + +return bar() + baz() + read() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalWithConditional") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + if true then foo = 6 end + return foo +end + +function baz() + foo = 6 + return foo +end + +return bar() + baz() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocal3WithConditionalRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function bar() + foo = 6 + return foo +end + +function baz() + foo = 6 + return foo +end + +function read() + if false then print(foo) end +end + +return bar() + baz() + read() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + +TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalInnerRead") +{ + ScopedFastFlag sff{"LuauLintGlobalNeverReadBeforeWritten", true}; + LintResult result = lint(R"( +function foo() + local f = function() return bar end + f() + bar = 42 +end + +function baz() bar = 0 end + +return foo() + baz() +)"); + + CHECK_EQ(result.warnings.size(), 0); +} + TEST_CASE_FIXTURE(Fixture, "GlobalAsLocalMulti") { LintResult result = lint(R"( diff --git a/tests/Parser.test.cpp b/tests/Parser.test.cpp index 77e49ce3c..7f6a6c0d4 100644 --- a/tests/Parser.test.cpp +++ b/tests/Parser.test.cpp @@ -1988,8 +1988,6 @@ TEST_CASE_FIXTURE(Fixture, "function_type_matching_parenthesis") TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - AstStat* stat = parse(R"( type A = {} type B = {} @@ -2005,8 +2003,6 @@ type G = (U...) -> T... TEST_CASE_FIXTURE(Fixture, "parse_type_alias_default_type_errors") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - matchParseError("type Y = {}", "Expected default type after type name", Location{{0, 20}, {0, 21}}); matchParseError("type Y = {}", "Expected default type pack after type pack name", Location{{0, 29}, {0, 30}}); matchParseError("type Y number> = {}", "Expected type pack after '=', got type", Location{{0, 14}, {0, 32}}); @@ -2574,8 +2570,6 @@ do end TEST_CASE_FIXTURE(Fixture, "recover_expected_type_pack") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ParseResult result = tryParse(R"( type Y = (T...) -> U... )"); diff --git a/tests/ToDot.test.cpp b/tests/ToDot.test.cpp index 0ca9c9949..29bdd8664 100644 --- a/tests/ToDot.test.cpp +++ b/tests/ToDot.test.cpp @@ -96,8 +96,6 @@ n2 [label="number"]; TEST_CASE_FIXTURE(Fixture, "function") { - ScopedFastFlag luauQuantifyInPlace2{"LuauQuantifyInPlace2", true}; - CheckResult result = check(R"( local function f(a, ...: string) return a end )"); diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index bbb262910..6713a589d 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -500,8 +500,6 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 332aba9e4..5f0295b0e 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -641,9 +641,6 @@ TEST_CASE_FIXTURE(Fixture, "transpile_to_string") TEST_CASE_FIXTURE(Fixture, "transpile_type_alias_default_type_parameters") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - std::string code = R"( type Packed = (T, U, V...)->(W...) local a: Packed diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 31d7ef10b..d584eb2d2 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -625,9 +625,8 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni ScopedFastFlag sff[] = { {"LuauTwoPassAliasDefinitionFix", true}, - // We also force these two flags because this surfaced an unfortunate interaction. + // We also force this flag because it surfaced an unfortunate interaction. {"LuauErrorRecoveryType", true}, - {"LuauQuantifyInPlace2", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index f3dfb214d..bf9907703 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -934,4 +934,31 @@ TEST_CASE_FIXTURE(Fixture, "assert_returns_false_and_string_iff_it_knows_the_fir CHECK_EQ("(nil) -> nil", toString(requireType("f"))); } +TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") +{ + CheckResult result = check(R"( + local t1: {a: number} = {a = 42} + local t2: {b: string} = {b = "hello"} + local t3: {boolean} = {false, true} + + local tf1 = table.freeze(t1) + local tf2 = table.freeze(t2) + local tf3 = table.freeze(t3) + + local a = tf1.a + local b = tf2.b + local c = tf3[2] + + local d = tf1.b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Key 'b' not found in table '{| a: number |}'", toString(result.errors[0])); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("string", toString(requireType("b"))); + CHECK_EQ("boolean", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index f8fccf6b0..c482847bf 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -697,4 +697,93 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_functions_should_be_memory_safe") +{ + ScopedFastFlag sffs[] = { + { "LuauTableSubtypingVariance2", true }, + { "LuauUnsealedTableLiteral", true }, + { "LuauPropertiesGetExpectedType", true }, + { "LuauRecursiveTypeParameterRestriction", true }, + }; + + CheckResult result = check(R"( +--!strict +-- At one point this produced a UAF +type T = { a: U, b: a } +type U = { c: T?, d : a } +local x: T = { a = { c = nil, d = 5 }, b = 37 } +x.a.c = x +local y: T = { a = { c = nil, d = 5 }, b = 37 } +y.a.c = y + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), + R"(Type 'y' could not be converted into 'T' +caused by: + Property 'a' is not compatible. Type '{ c: T?, d: number }' could not be converted into 'U' +caused by: + Property 'd' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification1") +{ + ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; + + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (create: () -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(create: () -> U...): U... + return create() + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification2") +{ + ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; + + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (create: () -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(create) + return create() + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "generic_type_pack_unification3") +{ + ScopedFastFlag sff{"LuauTxnLogSeesTypePacks2", true}; + + CheckResult result = check(R"( +--!strict +type Dispatcher = { + useMemo: (arg: S, create: (S) -> T...) -> T... +} + +local TheDispatcher: Dispatcher = { + useMemo = function(arg: T, create: (T) -> U...): U... + return create(arg) + end +} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.provisional.test.cpp b/tests/TypeInfer.provisional.test.cpp index eee0e0f17..2e16b21ec 100644 --- a/tests/TypeInfer.provisional.test.cpp +++ b/tests/TypeInfer.provisional.test.cpp @@ -8,7 +8,6 @@ #include LUAU_FASTFLAG(LuauEqConstraint) -LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -40,16 +39,6 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - const std::string old_expected = R"( - function f(a:{fn:()->(free,free...)}): () - if type(a) == 'boolean'then - local a1:boolean=a - elseif a.fn()then - local a2:{fn:()->(free,free...)}=a - end - end - )"; - const std::string expected = R"( function f(a:{fn:()->(a,b...)}): () if type(a) == 'boolean'then @@ -60,10 +49,7 @@ TEST_CASE_FIXTURE(Fixture, "typeguard_inference_incomplete") end )"; - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ(expected, decorateWithTypes(code)); - else - CHECK_EQ(old_expected, decorateWithTypes(code)); + CHECK_EQ(expected, decorateWithTypes(code)); } TEST_CASE_FIXTURE(Fixture, "xpcall_returns_what_f_returns") @@ -135,46 +121,6 @@ TEST_CASE_FIXTURE(Fixture, "setmetatable_constrains_free_type_into_free_table") CHECK_EQ("number", toString(tm->givenType)); } -TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") -{ - CheckResult result = check(R"( - local a: {x: number, y: number, [any]: any} | {y: number} - - function f(t) - t.y = 1 - return t - end - - local b = f(a) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // :( - // Should be the same as the type of a - REQUIRE_EQ("{| y: number |}", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") -{ - CheckResult result = check(R"( - local a: {y: number} | {x: number, y: number, [any]: any} - - function f(t) - t.y = 1 - return t - end - - local b = f(a) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // :( - // Should be the same as the type of a - REQUIRE_EQ("{| [any]: any, x: number, y: number |}", toString(requireType("b"))); -} - // Luau currently doesn't yet know how to allow assignments when the binding was refined. TEST_CASE_FIXTURE(Fixture, "while_body_are_also_refined") { @@ -557,25 +503,6 @@ TEST_CASE_FIXTURE(Fixture, "bail_early_on_typescript_port_of_Result_type" * doct } } -TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") -{ - CheckResult result = check(R"( - --!strict - local function setNumber(t: { p: number? }, x:number) t.p = x end - local function getString(t: { p: string? }):string return t.p or "" end - -- This shouldn't type-check! - local function oh(x:number): string - local t: {} = {} - setNumber(t, x) - return getString(t) - end - local s: string = oh(37) - )"); - - // Really this should return an error, but it doesn't - LUAU_REQUIRE_NO_ERRORS(result); -} - // Should be in TypeInfer.tables.test.cpp // It's unsound to instantiate tables containing generic methods, // since mutating properties means table properties should be invariant. @@ -600,25 +527,9 @@ TEST_CASE_FIXTURE(Fixture, "invariant_table_properties_means_instantiating_table LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") -{ - // Mutability in type function application right now can create strange recursive types - // TODO: instantiation right now is problematic, in this example should either leave the Table type alone - // or it should rename the type to 'Self' so that the result will be 'Self
' - CheckResult result = check(R"( -type Table = { a: number } -type Self = T -local a: Self
- )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "Table
"); -} - TEST_CASE_FIXTURE(Fixture, "do_not_ice_when_trying_to_pick_first_of_generic_type_pack") { ScopedFastFlag sff[]{ - {"LuauQuantifyInPlace2", true}, {"LuauReturnAnyInsteadOfICE", true}, }; @@ -664,8 +575,6 @@ TEST_CASE_FIXTURE(Fixture, "specialization_binds_with_prototypes_too_early") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") { - ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; - CheckResult result = check(R"( local function f() return end local g = function() return f() end @@ -676,8 +585,6 @@ TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_type_pack") TEST_CASE_FIXTURE(Fixture, "weird_fail_to_unify_variadic_pack") { - ScopedFastFlag sff{"LuauQuantifyInPlace2", true}; - CheckResult result = check(R"( --!strict local function f(...) return ... end diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index bff8926c6..a5147d56a 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -8,7 +8,6 @@ LUAU_FASTFLAG(LuauDiscriminableUnions2) LUAU_FASTFLAG(LuauWeakEqConstraint) -LUAU_FASTFLAG(LuauQuantifyInPlace2) using namespace Luau; @@ -1179,20 +1178,14 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_free_table_to_vector") { LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); - else - CHECK_EQ("Type '{- X: a, Y: b, Z: c -}' could not be converted into 'Instance'", toString(result.errors[0])); + CHECK_EQ("Type '{+ X: a, Y: b, Z: c +}' could not be converted into 'Instance'", toString(result.errors[0])); } CHECK_EQ("Vector3", toString(requireTypeAtPosition({5, 28}))); // type(vec) == "vector" CHECK_EQ("*unknown*", toString(requireTypeAtPosition({7, 28}))); // typeof(vec) == "Instance" - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" - else - CHECK_EQ("{- X: a, Y: b, Z: c -}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" + CHECK_EQ("{+ X: a, Y: b, Z: c +}", toString(requireTypeAtPosition({9, 28}))); // type(vec) ~= "vector" and typeof(vec) ~= "Instance" } TEST_CASE_FIXTURE(RefinementClassFixture, "typeguard_cast_instance_or_vector3_to_vector") diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 856549bde..3ed536ea6 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -465,30 +465,32 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si CHECK_EQ("((string) -> (b...), a) -> ()", toString(requireType("foo"))); } -// TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") -// { -// ScopedFastFlag sff[]{ -// {"LuauParseSingletonTypes", true}, -// {"LuauSingletonTypes", true}, -// {"LuauDiscriminableUnions2", true}, -// {"LuauEqConstraint", true}, -// {"LuauWidenIfSupertypeIsFree", true}, -// {"LuauWeakEqConstraint", false}, -// }; - -// CheckResult result = check(R"( -// local function foo(f, x): "hello"? -- anyone there? -// return if x == "hi" -// then f(x) -// else nil -// end -// )"); - -// LUAU_REQUIRE_NO_ERRORS(result); - -// CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); -// CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); -// } +TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauEqConstraint", true}, + {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWeakEqConstraint", false}, + {"LuauDoNotAccidentallyDependOnPointerOrdering", true} + }; + + CheckResult result = check(R"( + local function foo(f, x): "hello"? -- anyone there? + return if x == "hi" + then f(x) + else nil + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); + CHECK_EQ(R"(((string) -> (a, c...), b) -> "hello"?)", toString(requireType("foo"))); + // CHECK_EQ(R"(((string) -> ("hello"?, b...), a) -> "hello"?)", toString(requireType("foo"))); +} TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") { diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index aa949789c..da035ba14 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -1219,13 +1219,12 @@ TEST_CASE_FIXTURE(Fixture, "passing_compatible_unions_to_a_generic_table_without { CheckResult result = check(R"( type A = {x: number, y: number, [any]: any} | {y: number} - local a: A function f(t) t.y = 1 end - f(a) + f({y = 5} :: A) )"); LUAU_REQUIRE_NO_ERRORS(result); @@ -2165,6 +2164,44 @@ b() CHECK_EQ(toString(result.errors[0]), R"(Cannot call non-function t1 where t1 = { @metatable { __call: t1 }, { } })"); } +TEST_CASE_FIXTURE(Fixture, "table_subtyping_shouldn't_add_optional_properties_to_sealed_tables") +{ + ScopedFastFlag sffs[] = { + {"LuauTableSubtypingVariance2", true}, + {"LuauSubtypingAddOptPropsToUnsealedTables", true}, + }; + + CheckResult result = check(R"( + --!strict + local function setNumber(t: { p: number? }, x:number) t.p = x end + local function getString(t: { p: string? }):string return t.p or "" end + -- This shouldn't type-check! + local function oh(x:number): string + local t: {} = {} + setNumber(t, x) + return getString(t) + end + local s: string = oh(37) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "top_table_type") +{ + CheckResult result = check(R"( + --!strict + type Table = { [any] : any } + type HasTable = { p: Table? } + type HasHasTable = { p: HasTable? } + local t : Table = { p = 5 } + local u : HasTable = { p = { p = 5 } } + local v : HasHasTable = { p = { p = { p = 5 } } } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "length_operator_union") { CheckResult result = check(R"( @@ -2257,4 +2294,44 @@ TEST_CASE_FIXTURE(Fixture, "confusing_indexing") CHECK_EQ("number | string", toString(requireType("foo"))); } +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table") +{ + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + + CheckResult result = check(R"( + local a: {x: number, y: number, [any]: any} | {y: number} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a_table_2") +{ + ScopedFastFlag sff{"LuauDifferentOrderOfUnificationDoesntMatter", true}; + + CheckResult result = check(R"( + local a: {y: number} | {x: number, y: number, [any]: any} + + function f(t) + t.y = 1 + return t + end + + local b = f(a) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f44d9fd83..f63579b56 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -4044,6 +4044,49 @@ type t0 = any CHECK(ttv->instantiatedTypeParams.empty()); } +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +type K = X +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +local a = {} +a.x = 4 +local b: X +a.y = 5 +local c: X +c = b +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("a"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") { CheckResult result = check(R"( @@ -4065,6 +4108,21 @@ return m LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + // Mutability in type function application right now can create strange recursive types + CheckResult result = check(R"( +type Table = { a: number } +type Self = T +local a: Self
+ )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a")), "Table"); +} + TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") { TypeId mathTy = requireType(typeChecker.globalScope, "math"); @@ -5284,4 +5342,17 @@ TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") +{ + CheckResult result = check(R"( +local function f(x: string) + local p = x:split('a') + p = table.pack(table.unpack(p, 1, #p - 1)) + return p +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index c9bf51032..f6ee3ccce 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -7,8 +7,6 @@ #include "doctest.h" -LUAU_FASTFLAG(LuauQuantifyInPlace2); - using namespace Luau; LUAU_FASTFLAG(LuauUseCommittingTxnLog) @@ -167,10 +165,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "typepack_unification_should_trim_free_tails" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - if (FFlag::LuauQuantifyInPlace2) - CHECK_EQ("(number) -> boolean", toString(requireType("f"))); - else - CHECK_EQ("(number) -> (boolean)", toString(requireType("f"))); + CHECK_EQ("(number) -> boolean", toString(requireType("f"))); } TEST_CASE_FIXTURE(TryUnifyFixture, "variadic_type_pack_unification") diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index cbe2e48f2..6b96f4498 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -622,9 +622,6 @@ type Other = Packed TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_explicit") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: U } @@ -654,9 +651,6 @@ local c: Y = { a = "s" } TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_self") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: U } @@ -682,9 +676,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_chained") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: U, c: V } @@ -700,9 +691,6 @@ local b: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_explicit") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T...) -> () } local a: Y<> @@ -715,9 +703,6 @@ local a: Y<> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_ty") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T, b: (U...) -> T } @@ -731,9 +716,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_tp") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T...) -> U... } local a: Y @@ -746,9 +728,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_pack_self_chained_tp") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T...) -> U..., b: (T...) -> V... } local a: Y @@ -761,9 +740,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_default_mixed_self") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: (T, U, V...) -> W... } local a: Y @@ -782,9 +758,6 @@ local d: Y ()> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_errors") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = { a: T } local a: Y = { a = 2 } @@ -834,9 +807,6 @@ local a: Y<...number> TEST_CASE_FIXTURE(Fixture, "type_alias_default_export") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - fileResolver.source["Module/Types"] = R"( export type A = { a: T, b: U } export type B = { a: T, b: U } @@ -882,9 +852,6 @@ local h: Types.H<> TEST_CASE_FIXTURE(Fixture, "type_alias_default_type_skip_brackets") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type Y = (T...) -> number local a: Y @@ -897,9 +864,6 @@ local a: Y TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_confusing_types") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type A = (T, V...) -> (U, W...) type B = A @@ -914,9 +878,6 @@ type C = A TEST_CASE_FIXTURE(Fixture, "type_alias_defaults_recursive_type") { - ScopedFastFlag luauParseTypeAliasDefaults{"LuauParseTypeAliasDefaults", true}; - ScopedFastFlag luauTypeAliasDefaults{"LuauTypeAliasDefaults", true}; - CheckResult result = check(R"( type F ()> = (K) -> V type R = { m: F } diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 3b53ddfe3..0e0b6ebba 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -400,10 +400,10 @@ local e = a.z CHECK_EQ("Type 'A | B | C | D' does not have key 'z'", toString(result.errors[3])); } -TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") +TEST_CASE_FIXTURE(Fixture, "unify_unsealed_table_union_check") { CheckResult result = check(R"( -local x: { x: number } = { x = 3 } +local x = { x = 3 } type A = number? type B = string? local y: { x: number, y: A | B } @@ -413,7 +413,7 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); result = check(R"( -local x: { x: number } = { x = 3 } +local x = { x = 3 } local a: number? = 2 local y = {} @@ -426,6 +426,31 @@ y = x LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "unify_sealed_table_union_check") +{ + ScopedFastFlag sffs[] = { + {"LuauTableSubtypingVariance2", true}, + {"LuauUnsealedTableLiteral", true}, + {"LuauSubtypingAddOptPropsToUnsealedTables", true}, + }; + + CheckResult result = check(R"( + -- the difference between this and unify_unsealed_table_union_check is the type annotation on x +local t = { x = 3, y = true } +local x: { x: number } = t +type A = number? +type B = string? +local y: { x: number, y: A | B } +-- Shouldn't typecheck! +y = x +-- If it does, we can convert any type to any other type +y.y = 5 +local oh : boolean = t.y + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_CASE_FIXTURE(Fixture, "error_detailed_union_part") { CheckResult result = check(R"( diff --git a/tests/conformance/nextvar.lua b/tests/conformance/nextvar.lua index 94ba5ccfb..e85fcbe8e 100644 --- a/tests/conformance/nextvar.lua +++ b/tests/conformance/nextvar.lua @@ -512,4 +512,42 @@ do assert(#t == 7) end +-- test clone +do + local t = {a = 1, b = 2, 3, 4, 5} + local tt = table.clone(t) + + assert(#tt == 3) + assert(tt.a == 1 and tt.b == 2) + + t.c = 3 + assert(tt.c == nil) + + t = table.freeze({"test"}) + tt = table.clone(t) + assert(table.isfrozen(t) and not table.isfrozen(tt)) + + t = setmetatable({}, {}) + tt = table.clone(t) + assert(getmetatable(t) == getmetatable(tt)) + + t = setmetatable({}, {__metatable = "protected"}) + assert(not pcall(table.clone, t)) + + function order(t) + local r = '' + for k,v in pairs(t) do + r ..= tostring(v) + end + return v + end + + t = {a = 1, b = 2, c = 3, d = 4, e = 5, f = 6} + tt = table.clone(t) + assert(order(t) == order(tt)) + + assert(not pcall(table.clone)) + assert(not pcall(table.clone, 42)) +end + return"OK" From feea507be3f6991a76907e92220f50db05d7a98e Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Fri, 11 Mar 2022 08:31:18 -0800 Subject: [PATCH 29/32] Sync to upstream/release/518 --- Analysis/include/Luau/TxnLog.h | 55 - Analysis/include/Luau/TypePack.h | 1 - Analysis/include/Luau/Unifier.h | 1 - Analysis/src/Autocomplete.cpp | 56 +- Analysis/src/BuiltinDefinitions.cpp | 17 +- Analysis/src/TxnLog.cpp | 117 -- Analysis/src/TypeInfer.cpp | 658 ++++------ Analysis/src/TypePack.cpp | 26 +- Analysis/src/Unifier.cpp | 1814 +++++++-------------------- VM/include/lua.h | 35 +- VM/include/luaconf.h | 27 - VM/include/lualib.h | 4 +- VM/src/laux.cpp | 15 + VM/src/lbaselib.cpp | 11 +- VM/src/lgc.h | 7 + tests/Autocomplete.test.cpp | 52 +- tests/TypeInfer.aliases.test.cpp | 23 + tests/TypeInfer.builtins.test.cpp | 51 +- tests/TypeInfer.generics.test.cpp | 45 + tests/TypeInfer.tables.test.cpp | 50 + tests/TypeInfer.test.cpp | 38 +- tests/TypeInfer.tryUnify.test.cpp | 44 +- tests/TypeInfer.typePacks.cpp | 13 +- tests/TypeInfer.unionTypes.test.cpp | 10 +- tools/LuauVisualize.py | 107 ++ tools/lldb-formatters.lldb | 2 + 26 files changed, 1124 insertions(+), 2155 deletions(-) create mode 100644 tools/LuauVisualize.py create mode 100644 tools/lldb-formatters.lldb diff --git a/Analysis/include/Luau/TxnLog.h b/Analysis/include/Luau/TxnLog.h index f81053839..c8ebaaeb7 100644 --- a/Analysis/include/Luau/TxnLog.h +++ b/Analysis/include/Luau/TxnLog.h @@ -14,61 +14,6 @@ namespace Luau using TypeOrPackId = const void*; -// Log of where what TypeIds we are rebinding and what they used to be -// Remove with LuauUseCommitTxnLog -struct DEPRECATED_TxnLog -{ - DEPRECATED_TxnLog() - : originalSeenSize(0) - , ownedSeen() - , sharedSeen(&ownedSeen) - { - } - - explicit DEPRECATED_TxnLog(std::vector>* sharedSeen) - : originalSeenSize(sharedSeen->size()) - , ownedSeen() - , sharedSeen(sharedSeen) - { - } - - DEPRECATED_TxnLog(const DEPRECATED_TxnLog&) = delete; - DEPRECATED_TxnLog& operator=(const DEPRECATED_TxnLog&) = delete; - - DEPRECATED_TxnLog(DEPRECATED_TxnLog&&) = default; - DEPRECATED_TxnLog& operator=(DEPRECATED_TxnLog&&) = default; - - void operator()(TypeId a); - void operator()(TypePackId a); - void operator()(TableTypeVar* a); - - void rollback(); - - void concat(DEPRECATED_TxnLog rhs); - - bool haveSeen(TypeId lhs, TypeId rhs); - void pushSeen(TypeId lhs, TypeId rhs); - void popSeen(TypeId lhs, TypeId rhs); - - bool haveSeen(TypePackId lhs, TypePackId rhs); - void pushSeen(TypePackId lhs, TypePackId rhs); - void popSeen(TypePackId lhs, TypePackId rhs); - -private: - std::vector> typeVarChanges; - std::vector> typePackChanges; - std::vector>> tableChanges; - size_t originalSeenSize; - - bool haveSeen(TypeOrPackId lhs, TypeOrPackId rhs); - void pushSeen(TypeOrPackId lhs, TypeOrPackId rhs); - void popSeen(TypeOrPackId lhs, TypeOrPackId rhs); - -public: - std::vector> ownedSeen; // used to avoid infinite recursion when types are cyclic - std::vector>* sharedSeen; // shared with all the descendent logs -}; - // Pending state for a TypeVar. Generated by a TxnLog and committed via // TxnLog::commit. struct PendingType diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index c74bad114..946be3561 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -105,7 +105,6 @@ struct TypePackIterator const TypePack* tp = nullptr; size_t currentIndex = 0; - // Only used if LuauUseCommittingTxnLog is true. const TxnLog* log; }; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 4c0462fe5..71958f4a1 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -45,7 +45,6 @@ struct Unifier TypeArena* const types; Mode mode; - DEPRECATED_TxnLog DEPRECATED_log; TxnLog log; ErrorVec errors; Location location; diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index c3de8d0e1..e94c432f8 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -13,9 +13,6 @@ #include #include -LUAU_FASTFLAG(LuauUseCommittingTxnLog) -LUAU_FASTFLAGVARIABLE(LuauAutocompleteAvoidMutation, false); -LUAU_FASTFLAGVARIABLE(LuauMissingFollowACMetatables, false); LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); static const std::unordered_set kStatementStartingKeywords = { @@ -240,28 +237,9 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); - if (FFlag::LuauAutocompleteAvoidMutation && !FFlag::LuauUseCommittingTxnLog) - { - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - superTy = clone(superTy, *typeArena, seenTypes, seenTypePacks, cloneState); - subTy = clone(subTy, *typeArena, seenTypes, seenTypePacks, cloneState); - - auto errors = unifier.canUnify(subTy, superTy); - return errors.empty(); - } - else - { - unifier.tryUnify(subTy, superTy); - - bool ok = unifier.errors.empty(); - - if (!FFlag::LuauUseCommittingTxnLog) - unifier.DEPRECATED_log.rollback(); - - return ok; - } + unifier.tryUnify(subTy, superTy); + bool ok = unifier.errors.empty(); + return ok; }; auto typeAtPosition = findExpectedTypeAt(module, node, position); @@ -403,28 +381,14 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId auto indexIt = mtable->props.find("__index"); if (indexIt != mtable->props.end()) { - if (FFlag::LuauMissingFollowACMetatables) + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) { - TypeId followed = follow(indexIt->second.type); - if (get(followed) || get(followed)) - autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); - else if (auto indexFunction = get(followed)) - { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); - } - } - else - { - if (get(indexIt->second.type) || get(indexIt->second.type)) - autocompleteProps(module, typeArena, indexIt->second.type, indexType, nodes, result, seen); - else if (auto indexFunction = get(indexIt->second.type)) - { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); - } + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); } } } diff --git a/Analysis/src/BuiltinDefinitions.cpp b/Analysis/src/BuiltinDefinitions.cpp index e4e5dab82..bf9ef303f 100644 --- a/Analysis/src/BuiltinDefinitions.cpp +++ b/Analysis/src/BuiltinDefinitions.cpp @@ -10,6 +10,7 @@ LUAU_FASTFLAG(LuauAssertStripsFalsyTypes) LUAU_FASTFLAGVARIABLE(LuauTableCloneType, false) +LUAU_FASTFLAGVARIABLE(LuauSetMetaTableArgsCheck, false) /** FIXME: Many of these type definitions are not quite completely accurate. * @@ -376,11 +377,19 @@ static std::optional> magicFunctionSetMetaTable( TypeId mtTy = arena.addType(mtv); - AstExpr* targetExpr = expr.args.data[0]; - if (AstExprLocal* targetLocal = targetExpr->as()) + if (FFlag::LuauSetMetaTableArgsCheck && expr.args.size < 1) { - const Name targetName(targetLocal->local->name.value); - scope->bindings[targetLocal->local] = Binding{mtTy, expr.location}; + return ExprResult{}; + } + + if (!FFlag::LuauSetMetaTableArgsCheck || !expr.self) + { + AstExpr* targetExpr = expr.args.data[0]; + if (AstExprLocal* targetLocal = targetExpr->as()) + { + const Name targetName(targetLocal->local->name.value); + scope->bindings[targetLocal->local] = Binding{mtTy, expr.location}; + } } return ExprResult{arena.addTypePack({mtTy})}; diff --git a/Analysis/src/TxnLog.cpp b/Analysis/src/TxnLog.cpp index c7bf1e62c..876f5f05a 100644 --- a/Analysis/src/TxnLog.cpp +++ b/Analysis/src/TxnLog.cpp @@ -7,110 +7,9 @@ #include #include -LUAU_FASTFLAGVARIABLE(LuauUseCommittingTxnLog, false) - namespace Luau { -void DEPRECATED_TxnLog::operator()(TypeId a) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - typeVarChanges.emplace_back(a, *a); -} - -void DEPRECATED_TxnLog::operator()(TypePackId a) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - typePackChanges.emplace_back(a, *a); -} - -void DEPRECATED_TxnLog::operator()(TableTypeVar* a) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - tableChanges.emplace_back(a, a->boundTo); -} - -void DEPRECATED_TxnLog::rollback() -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - for (auto it = typeVarChanges.rbegin(); it != typeVarChanges.rend(); ++it) - std::swap(*asMutable(it->first), it->second); - - for (auto it = typePackChanges.rbegin(); it != typePackChanges.rend(); ++it) - std::swap(*asMutable(it->first), it->second); - - for (auto it = tableChanges.rbegin(); it != tableChanges.rend(); ++it) - std::swap(it->first->boundTo, it->second); - - LUAU_ASSERT(originalSeenSize <= sharedSeen->size()); - sharedSeen->resize(originalSeenSize); -} - -void DEPRECATED_TxnLog::concat(DEPRECATED_TxnLog rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - typeVarChanges.insert(typeVarChanges.end(), rhs.typeVarChanges.begin(), rhs.typeVarChanges.end()); - rhs.typeVarChanges.clear(); - - typePackChanges.insert(typePackChanges.end(), rhs.typePackChanges.begin(), rhs.typePackChanges.end()); - rhs.typePackChanges.clear(); - - tableChanges.insert(tableChanges.end(), rhs.tableChanges.begin(), rhs.tableChanges.end()); - rhs.tableChanges.clear(); -} - -bool DEPRECATED_TxnLog::haveSeen(TypeId lhs, TypeId rhs) -{ - return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::pushSeen(TypeId lhs, TypeId rhs) -{ - pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::popSeen(TypeId lhs, TypeId rhs) -{ - popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -bool DEPRECATED_TxnLog::haveSeen(TypePackId lhs, TypePackId rhs) -{ - return haveSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::pushSeen(TypePackId lhs, TypePackId rhs) -{ - pushSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -void DEPRECATED_TxnLog::popSeen(TypePackId lhs, TypePackId rhs) -{ - popSeen((TypeOrPackId)lhs, (TypeOrPackId)rhs); -} - -bool DEPRECATED_TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - return (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)); -} - -void DEPRECATED_TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - sharedSeen->push_back(sortedPair); -} - -void DEPRECATED_TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) -{ - LUAU_ASSERT(!FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); - LUAU_ASSERT(sortedPair == sharedSeen->back()); - sharedSeen->pop_back(); -} - const std::string nullPendingResult = ""; std::string toString(PendingType* pending) @@ -170,8 +69,6 @@ const TxnLog* TxnLog::empty() void TxnLog::concat(TxnLog rhs) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (auto& [ty, rep] : rhs.typeVarChanges) typeVarChanges[ty] = std::move(rep); @@ -181,8 +78,6 @@ void TxnLog::concat(TxnLog rhs) void TxnLog::commit() { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - for (auto& [ty, rep] : typeVarChanges) *asMutable(ty) = rep.get()->pending; @@ -194,16 +89,12 @@ void TxnLog::commit() void TxnLog::clear() { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - typeVarChanges.clear(); typePackChanges.clear(); } TxnLog TxnLog::inverse() { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - TxnLog inversed(sharedSeen); for (auto& [ty, _rep] : typeVarChanges) @@ -247,8 +138,6 @@ void TxnLog::popSeen(TypePackId lhs, TypePackId rhs) bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); if (sharedSeen->end() != std::find(sharedSeen->begin(), sharedSeen->end(), sortedPair)) { @@ -265,16 +154,12 @@ bool TxnLog::haveSeen(TypeOrPackId lhs, TypeOrPackId rhs) const void TxnLog::pushSeen(TypeOrPackId lhs, TypeOrPackId rhs) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); sharedSeen->push_back(sortedPair); } void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); - const std::pair sortedPair = (lhs > rhs) ? std::make_pair(lhs, rhs) : std::make_pair(rhs, lhs); LUAU_ASSERT(sortedPair == sharedSeen->back()); sharedSeen->pop_back(); @@ -282,7 +167,6 @@ void TxnLog::popSeen(TypeOrPackId lhs, TypeOrPackId rhs) PendingType* TxnLog::queue(TypeId ty) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(!ty->persistent); // Explicitly don't look in ancestors. If we have discovered something new @@ -296,7 +180,6 @@ PendingType* TxnLog::queue(TypeId ty) PendingTypePack* TxnLog::queue(TypePackId tp) { - LUAU_ASSERT(FFlag::LuauUseCommittingTxnLog); LUAU_ASSERT(!tp->persistent); // Explicitly don't look in ancestors. If we have discovered something new diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 8e6b3b52f..3fe4c90ef 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -24,7 +24,6 @@ LUAU_FASTINTVARIABLE(LuauCheckRecursionLimit, 500) LUAU_FASTFLAG(LuauKnowsTheDataModel3) LUAU_FASTFLAGVARIABLE(LuauEqConstraint, false) LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) @@ -36,6 +35,7 @@ LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) @@ -43,6 +43,8 @@ LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) +LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) +LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) namespace Luau { @@ -652,18 +654,15 @@ ErrorVec TypeChecker::tryUnify_(Id subTy, Id superTy, const Location& location) { Unifier state = mkUnifier(location); - if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + if (FFlag::DebugLuauFreezeDuringUnification) freeze(currentModule->internalTypes); state.tryUnify(subTy, superTy); - if (FFlag::LuauUseCommittingTxnLog && FFlag::DebugLuauFreezeDuringUnification) + if (FFlag::DebugLuauFreezeDuringUnification) unfreeze(currentModule->internalTypes); - if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); - - if (state.errors.empty() && FFlag::LuauUseCommittingTxnLog) + if (state.errors.empty()) state.log.commit(); return state.errors; @@ -847,8 +846,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatLocal& local) state.tryUnify(valuePack, variablePack); reportErrors(state.errors); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); // In the code 'local T = {}', we wish to ascribe the name 'T' to the type of the table for error-reporting purposes. // We also want to do this for 'local T = setmetatable(...)'. @@ -1040,8 +1038,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatForIn& forin) Unifier state = mkUnifier(firstValue->location); checkArgumentList(loopScope, state, argPack, iterFunc->argTypes, /*argLocations*/ {}); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); } @@ -1102,8 +1099,53 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify) + { + TypeId exprTy = checkExpr(scope, *name->expr).type; + TableTypeVar* ttv = getMutableTableType(exprTy); + if (!ttv) + { + if (isTableIntersection(exprTy)) + reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + else if (!get(exprTy) && !get(exprTy)) + reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); + } + else if (ttv->state == TableState::Sealed) + reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + + ty = follow(ty); + + if (ttv && ttv->state != TableState::Sealed) + ttv->props[name->index.value] = {ty, /* deprecated */ false, {}, name->indexLocation}; + + if (function.func->self) + { + const FunctionTypeVar* funTy = get(ty); + if (!funTy) + ice("Methods should be functions"); + + std::optional arg0 = first(funTy->argTypes); + if (!arg0) + ice("Methods should always have at least 1 argument (self)"); + } + + checkFunctionBody(funScope, ty, *function.func); + + if (ttv && ttv->state != TableState::Sealed) + ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; + } + else if (FFlag::LuauStatFunctionSimplify) + { + LUAU_ASSERT(function.name->is()); + + ty = follow(ty); + + checkFunctionBody(funScope, ty, *function.func); + } else if (function.func->self) { + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + AstExprIndexName* indexName = function.name->as(); if (!indexName) ice("member function declaration has malformed name expression"); @@ -1141,6 +1183,8 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + TypeId leftType = checkLValueBinding(scope, *function.name); checkFunctionBody(funScope, ty, *function.func); @@ -1217,6 +1261,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias LUAU_ASSERT(ftv); ftv->forwardedTypeAlias = true; bindingsMap[name] = {std::move(generics), std::move(genericPacks), ty}; + + if (FFlag::LuauFixIncorrectLineNumberDuplicateType) + scope->typeAliasLocations[name] = typealias.location; } } else @@ -2102,9 +2149,7 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExprUn Unifier state = mkUnifier(expr.location); state.tryUnify(actualFunctionType, expectedFunctionType, /*isFunctionCall*/ true); - - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); TypeId retType = first(retTypePack).value_or(nilType); if (!state.errors.empty()) @@ -2283,9 +2328,7 @@ TypeId TypeChecker::checkRelationalOperation( if (!isEquality) { state.tryUnify(rhsType, lhsType); - - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); } bool needsMetamethod = !isEquality; @@ -2336,8 +2379,7 @@ TypeId TypeChecker::checkRelationalOperation( return errorRecoveryType(booleanType); } - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); } } @@ -2347,8 +2389,7 @@ TypeId TypeChecker::checkRelationalOperation( state.tryUnify( instantiate(scope, actualFunctionType, expr.location), instantiate(scope, *metamethod, expr.location), /*isFunctionCall*/ true); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); return booleanType; @@ -2464,25 +2505,15 @@ TypeId TypeChecker::checkBinaryOperation( TypePackId fallbackArguments = freshTypePack(scope); TypeId fallbackFunctionType = addType(FunctionTypeVar(scope->level, fallbackArguments, retTypePack)); state.errors.clear(); - - if (FFlag::LuauUseCommittingTxnLog) - { - state.log.clear(); - } - else - { - state.DEPRECATED_log.rollback(); - } + state.log.clear(); state.tryUnify(actualFunctionType, fallbackFunctionType, /*isFunctionCall*/ true); - if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + if (state.errors.empty()) state.log.commit(); - else if (!state.errors.empty() && !FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); } - if (FFlag::LuauUseCommittingTxnLog && !hasErrors) + if (!hasErrors) { state.log.commit(); } @@ -2729,13 +2760,11 @@ TypeId TypeChecker::checkLValueBinding(const ScopePtr& scope, const AstExprIndex TypeId retType = indexer->indexResultType; if (!state.errors.empty()) { - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); reportError(expr.location, UnknownProperty{lhs, name}); retType = errorRecoveryType(retType); } - else if (FFlag::LuauUseCommittingTxnLog) + else state.log.commit(); return retType; @@ -3209,7 +3238,7 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A } // Returns the minimum number of arguments the argument list can accept. -static size_t getMinParameterCount(TypePackId tp) +static size_t getMinParameterCount_DEPRECATED(TypePackId tp) { size_t minCount = 0; size_t optionalCount = 0; @@ -3235,6 +3264,32 @@ static size_t getMinParameterCount(TypePackId tp) return minCount; } +static size_t getMinParameterCount(TxnLog* log, TypePackId tp) +{ + size_t minCount = 0; + size_t optionalCount = 0; + + auto it = begin(tp, log); + auto endIter = end(tp); + + while (it != endIter) + { + TypeId ty = *it; + if (isOptional(ty)) + ++optionalCount; + else + { + minCount += optionalCount; + optionalCount = 0; + minCount++; + } + + ++it; + } + + return minCount; +} + void TypeChecker::checkArgumentList( const ScopePtr& scope, Unifier& state, TypePackId argPack, TypePackId paramPack, const std::vector& argLocations) { @@ -3248,396 +3303,199 @@ void TypeChecker::checkArgumentList( size_t paramIndex = 0; - size_t minParams = getMinParameterCount(paramPack); + size_t minParams = FFlag::LuauFixIncorrectLineNumberDuplicateType ? 0 : getMinParameterCount_DEPRECATED(paramPack); - if (FFlag::LuauUseCommittingTxnLog) + while (true) { - while (true) - { - state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; + state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - if (argIter == endIter && paramIter == endIter) - { - std::optional argTail = argIter.tail(); - std::optional paramTail = paramIter.tail(); + if (argIter == endIter && paramIter == endIter) + { + std::optional argTail = argIter.tail(); + std::optional paramTail = paramIter.tail(); - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. + // If we hit the end of both type packs simultaneously, then there are definitely no further type + // errors to report. All we need to do is tie up any free tails. + // + // If one side has a free tail and the other has none at all, we create an empty pack and bind the + // free tail to that. - if (argTail) - { - if (state.log.getMutable(state.log.follow(*argTail))) - { - if (paramTail) - state.tryUnify(*paramTail, *argTail); - else - state.log.replace(*argTail, TypePackVar(TypePack{{}})); - } - } - else if (paramTail) + if (argTail) + { + if (state.log.getMutable(state.log.follow(*argTail))) { - // argTail is definitely empty - if (state.log.getMutable(state.log.follow(*paramTail))) - state.log.replace(*paramTail, TypePackVar(TypePack{{}})); + if (paramTail) + state.tryUnify(*paramTail, *argTail); + else + state.log.replace(*argTail, TypePackVar(TypePack{{}})); } - - return; } - else if (argIter == endIter) + else if (paramTail) { - // Not enough arguments. - - // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. - if (argIter.tail()) - { - TypePackId tail = *argIter.tail(); - if (state.log.getMutable(tail)) - { - // Unify remaining parameters so we don't leave any free-types hanging around. - while (paramIter != endIter) - { - state.tryUnify(errorRecoveryType(anyType), *paramIter); - ++paramIter; - } - return; - } - else if (auto vtp = state.log.getMutable(tail)) - { - while (paramIter != endIter) - { - state.tryUnify(vtp->ty, *paramIter); - ++paramIter; - } - - return; - } - else if (state.log.getMutable(tail)) - { - std::vector rest; - rest.reserve(std::distance(paramIter, endIter)); - while (paramIter != endIter) - { - rest.push_back(*paramIter); - ++paramIter; - } + // argTail is definitely empty + if (state.log.getMutable(state.log.follow(*paramTail))) + state.log.replace(*paramTail, TypePackVar(TypePack{{}})); + } - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(varPack, tail); - return; - } - } + return; + } + else if (argIter == endIter) + { + // Not enough arguments. - // If any remaining unfulfilled parameters are nonoptional, this is a problem. - while (paramIter != endIter) - { - TypeId t = state.log.follow(*paramIter); - if (isOptional(t)) - { - } // ok - else if (state.log.getMutable(t)) - { - } // ok - else if (isNonstrictMode() && state.log.getMutable(t)) - { - } // ok - else - { - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); - return; - } - ++paramIter; - } - } - else if (paramIter == endIter) + // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. + if (argIter.tail()) { - // too many parameters passed - if (!paramIter.tail()) + TypePackId tail = *argIter.tail(); + if (state.log.getMutable(tail)) { - while (argIter != endIter) + // Unify remaining parameters so we don't leave any free-types hanging around. + while (paramIter != endIter) { - // The use of unify here is deliberate. We don't want this unification - // to be undoable. - unify(errorRecoveryType(scope), *argIter, state.location); - ++argIter; + state.tryUnify(errorRecoveryType(anyType), *paramIter); + ++paramIter; } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - TypePackId tail = state.log.follow(*paramIter.tail()); - - if (state.log.getMutable(tail)) - { - // Function is variadic. Ok. return; } else if (auto vtp = state.log.getMutable(tail)) { - // Function is variadic and requires that all subsequent parameters - // be compatible with a type. - size_t argIndex = paramIndex; - while (argIter != endIter) + while (paramIter != endIter) { - Location location = state.location; - - if (argIndex < argLocations.size()) - location = argLocations[argIndex]; - - unify(*argIter, vtp->ty, location); - ++argIter; - ++argIndex; + state.tryUnify(vtp->ty, *paramIter); + ++paramIter; } return; } else if (state.log.getMutable(tail)) { - // Create a type pack out of the remaining argument types - // and unify it with the tail. std::vector rest; - rest.reserve(std::distance(argIter, endIter)); - while (argIter != endIter) + rest.reserve(std::distance(paramIter, endIter)); + while (paramIter != endIter) { - rest.push_back(*argIter); - ++argIter; + rest.push_back(*paramIter); + ++paramIter; } - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); state.tryUnify(varPack, tail); return; } - else if (state.log.getMutable(tail)) + } + + // If any remaining unfulfilled parameters are nonoptional, this is a problem. + while (paramIter != endIter) + { + TypeId t = state.log.follow(*paramIter); + if (isOptional(t)) { - state.log.replace(tail, TypePackVar(TypePack{{}})); - return; - } - else if (state.log.getMutable(tail)) + } // ok + else if (state.log.getMutable(t)) { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + } // ok + else if (isNonstrictMode() && state.log.getMutable(t)) + { + } // ok + else + { + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); return; } - } - else - { - unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); - ++argIter; ++paramIter; } - - ++paramIndex; } - } - else - { - while (true) + else if (paramIter == endIter) { - state.location = paramIndex < argLocations.size() ? argLocations[paramIndex] : state.location; - - if (argIter == endIter && paramIter == endIter) + // too many parameters passed + if (!paramIter.tail()) { - std::optional argTail = argIter.tail(); - std::optional paramTail = paramIter.tail(); - - // If we hit the end of both type packs simultaneously, then there are definitely no further type - // errors to report. All we need to do is tie up any free tails. - // - // If one side has a free tail and the other has none at all, we create an empty pack and bind the - // free tail to that. - - if (argTail) + while (argIter != endIter) { - if (get(*argTail)) - { - if (paramTail) - state.tryUnify(*paramTail, *argTail); - else - { - state.DEPRECATED_log(*argTail); - *asMutable(*argTail) = TypePack{{}}; - } - } - } - else if (paramTail) - { - // argTail is definitely empty - if (get(*paramTail)) - { - state.DEPRECATED_log(*paramTail); - *asMutable(*paramTail) = TypePack{{}}; - } + // The use of unify here is deliberate. We don't want this unification + // to be undoable. + unify(errorRecoveryType(scope), *argIter, state.location); + ++argIter; } + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; + } + TypePackId tail = state.log.follow(*paramIter.tail()); + if (state.log.getMutable(tail)) + { + // Function is variadic. Ok. return; } - else if (argIter == endIter) + else if (auto vtp = state.log.getMutable(tail)) { - // Not enough arguments. - - // Might be ok if we are forwarding a vararg along. This is a common thing to occur in nonstrict mode. - if (argIter.tail()) + // Function is variadic and requires that all subsequent parameters + // be compatible with a type. + size_t argIndex = paramIndex; + while (argIter != endIter) { - TypePackId tail = *argIter.tail(); - if (get(tail)) - { - // Unify remaining parameters so we don't leave any free-types hanging around. - while (paramIter != endIter) - { - state.tryUnify(*paramIter, errorRecoveryType(anyType)); - ++paramIter; - } - return; - } - else if (auto vtp = get(tail)) - { - while (paramIter != endIter) - { - state.tryUnify(*paramIter, vtp->ty); - ++paramIter; - } + Location location = state.location; - return; - } - else if (get(tail)) - { - std::vector rest; - rest.reserve(std::distance(paramIter, endIter)); - while (paramIter != endIter) - { - rest.push_back(*paramIter); - ++paramIter; - } + if (argIndex < argLocations.size()) + location = argLocations[argIndex]; - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, paramIter.tail()}}); - state.tryUnify(varPack, tail); - return; - } + unify(*argIter, vtp->ty, location); + ++argIter; + ++argIndex; } - // If any remaining unfulfilled parameters are nonoptional, this is a problem. - while (paramIter != endIter) - { - TypeId t = follow(*paramIter); - if (isOptional(t)) - { - } // ok - else if (get(t)) - { - } // ok - else if (isNonstrictMode() && get(t)) - { - } // ok - else - { - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); - return; - } - ++paramIter; - } + return; } - else if (paramIter == endIter) + else if (state.log.getMutable(tail)) { - // too many parameters passed - if (!paramIter.tail()) - { - while (argIter != endIter) - { - unify(*argIter, errorRecoveryType(scope), state.location); - ++argIter; - } - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } - TypePackId tail = *paramIter.tail(); - - if (get(tail)) - { - // Function is variadic. Ok. - return; - } - else if (auto vtp = get(tail)) - { - // Function is variadic and requires that all subsequent parameters - // be compatible with a type. - size_t argIndex = paramIndex; - while (argIter != endIter) - { - Location location = state.location; - - if (argIndex < argLocations.size()) - location = argLocations[argIndex]; - - unify(*argIter, vtp->ty, location); - ++argIter; - ++argIndex; - } - - return; - } - else if (get(tail)) + // Create a type pack out of the remaining argument types + // and unify it with the tail. + std::vector rest; + rest.reserve(std::distance(argIter, endIter)); + while (argIter != endIter) { - // Create a type pack out of the remaining argument types - // and unify it with the tail. - std::vector rest; - rest.reserve(std::distance(argIter, endIter)); - while (argIter != endIter) - { - rest.push_back(*argIter); - ++argIter; - } - - TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(tail, varPack); - return; + rest.push_back(*argIter); + ++argIter; } - else if (get(tail)) - { - if (FFlag::LuauUseCommittingTxnLog) - { - state.log.replace(tail, TypePackVar(TypePack{{}})); - } - else - { - state.DEPRECATED_log(tail); - *asMutable(tail) = TypePack{}; - } - return; - } - else if (get(tail)) - { - // For this case, we want the error span to cover every errant extra parameter - Location location = state.location; - if (!argLocations.empty()) - location = {state.location.begin, argLocations.back().end}; - // TODO: Better error message? - state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); - return; - } + TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); + state.tryUnify(varPack, tail); + return; } - else + else if (state.log.getMutable(tail)) { - unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); - ++argIter; - ++paramIter; + state.log.replace(tail, TypePackVar(TypePack{{}})); + return; + } + else if (state.log.getMutable(tail)) + { + // For this case, we want the error span to cover every errant extra parameter + Location location = state.location; + if (!argLocations.empty()) + location = {state.location.begin, argLocations.back().end}; + // TODO: Better error message? + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + minParams = getMinParameterCount(&state.log, paramPack); + state.reportError(TypeError{location, CountMismatch{minParams, std::distance(begin(argPack), end(argPack))}}); + return; } - - ++paramIndex; } + else + { + unifyWithInstantiationIfNeeded(scope, *argIter, *paramIter, state); + ++argIter; + ++paramIter; + } + + ++paramIndex; } } @@ -3882,9 +3740,6 @@ std::optional> TypeChecker::checkCallOverload(const Scope checkArgumentList(scope, state, retPack, ftv->retType, /*argLocations*/ {}); if (!state.errors.empty()) { - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); - return {}; } @@ -3912,14 +3767,10 @@ std::optional> TypeChecker::checkCallOverload(const Scope overloadsThatDont.push_back(fn); errors.emplace_back(std::move(state.errors), args->head, ftv); - - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); } else { - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) { @@ -3976,8 +3827,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { - if (FFlag::LuauUseCommittingTxnLog) - editedState.log.commit(); + editedState.log.commit(); reportError(TypeError{expr.location, FunctionDoesNotTakeSelf{}}); // This is a little bit suspect: If this overload would work with a . replaced by a : @@ -3987,8 +3837,6 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else if (!FFlag::LuauUseCommittingTxnLog) - editedState.DEPRECATED_log.rollback(); } else if (ftv->hasSelf) { @@ -4010,8 +3858,7 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal if (editedState.errors.empty()) { - if (FFlag::LuauUseCommittingTxnLog) - editedState.log.commit(); + editedState.log.commit(); reportError(TypeError{expr.location, FunctionRequiresSelf{}}); // This is a little bit suspect: If this overload would work with a : replaced by a . @@ -4021,8 +3868,6 @@ bool TypeChecker::handleSelfCallMismatch(const ScopePtr& scope, const AstExprCal // checkArgumentList(scope, editedState, retPack, ftv->retType, retLocations, CountMismatch::Return); return true; } - else if (!FFlag::LuauUseCommittingTxnLog) - editedState.DEPRECATED_log.rollback(); } } } @@ -4082,7 +3927,7 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast checkArgumentList(scope, state, argPack, ftv->argTypes, argLocations); } - if (FFlag::LuauUseCommittingTxnLog && state.errors.empty()) + if (state.errors.empty()) state.log.commit(); if (i > 0) @@ -4092,9 +3937,6 @@ void TypeChecker::reportOverloadResolutionError(const ScopePtr& scope, const Ast s += "and "; s += toString(overload); - - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); } if (overloadsThatMatchArgCount.size() == 0) @@ -4168,24 +4010,16 @@ ExprResult TypeChecker::checkExprList(const ScopePtr& scope, const L // just performed. There's not a great way to pass that into checkExpr. Instead, we store // the inverse of the current log, and commit it. When we're done, we'll commit all the // inverses. This isn't optimal, and a better solution is welcome here. - if (FFlag::LuauUseCommittingTxnLog) - { - inverseLogs.push_back(state.log.inverse()); - state.log.commit(); - } + inverseLogs.push_back(state.log.inverse()); + state.log.commit(); } tp->head.push_back(actualType); } } - if (FFlag::LuauUseCommittingTxnLog) - { - for (TxnLog& log : inverseLogs) - log.commit(); - } - else - state.DEPRECATED_log.rollback(); + for (TxnLog& log : inverseLogs) + log.commit(); return {pack, predicates}; } @@ -4294,8 +4128,7 @@ bool TypeChecker::unify(TypeId subTy, TypeId superTy, const Location& location, Unifier state = mkUnifier(location); state.tryUnify(subTy, superTy, options.isFunctionCall); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); @@ -4308,8 +4141,7 @@ bool TypeChecker::unify(TypePackId subTy, TypePackId superTy, const Location& lo state.ctx = ctx; state.tryUnify(subTy, superTy); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); @@ -4321,8 +4153,7 @@ bool TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s Unifier state = mkUnifier(location); unifyWithInstantiationIfNeeded(scope, subTy, superTy, state); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); reportErrors(state.errors); @@ -4352,31 +4183,18 @@ void TypeChecker::unifyWithInstantiationIfNeeded(const ScopePtr& scope, TypeId s if (subTy == instantiated) { // Instantiating the argument made no difference, so just report any child errors - if (FFlag::LuauUseCommittingTxnLog) - state.log.concat(std::move(child.log)); - else - state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); + state.log.concat(std::move(child.log)); state.errors.insert(state.errors.end(), child.errors.begin(), child.errors.end()); } else { - if (!FFlag::LuauUseCommittingTxnLog) - child.DEPRECATED_log.rollback(); - state.tryUnify(instantiated, superTy, /*isFunctionCall*/ false); } } else { - if (FFlag::LuauUseCommittingTxnLog) - { - state.log.concat(std::move(child.log)); - } - else - { - state.DEPRECATED_log.concat(std::move(child.DEPRECATED_log)); - } + state.log.concat(std::move(child.log)); } } } @@ -4540,7 +4358,7 @@ TypeId TypeChecker::quantify(const ScopePtr& scope, TypeId ty, Location location TypeId TypeChecker::instantiate(const ScopePtr& scope, TypeId ty, Location location, const TxnLog* log) { - Instantiation instantiation{FFlag::LuauUseCommittingTxnLog ? log : TxnLog::empty(), ¤tModule->internalTypes, scope->level}; + Instantiation instantiation{log, ¤tModule->internalTypes, scope->level}; std::optional instantiated = instantiation.substitute(ty); if (instantiated.has_value()) return *instantiated; diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index b15548a8d..91123f468 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -5,8 +5,6 @@ #include -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - namespace Luau { @@ -51,16 +49,8 @@ TypePackIterator::TypePackIterator(TypePackId typePack, const TxnLog* log) { while (tp && tp->head.empty()) { - if (FFlag::LuauUseCommittingTxnLog) - { - currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; - tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; - } - else - { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; - } + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; } } @@ -71,16 +61,8 @@ TypePackIterator& TypePackIterator::operator++() ++currentIndex; while (tp && currentIndex >= tp->head.size()) { - if (FFlag::LuauUseCommittingTxnLog) - { - currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; - tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; - } - else - { - currentTypePack = tp->tail ? follow(*tp->tail) : nullptr; - tp = currentTypePack ? get(currentTypePack) : nullptr; - } + currentTypePack = tp->tail ? log->follow(*tp->tail) : nullptr; + tp = currentTypePack ? log->getMutable(currentTypePack) : nullptr; currentIndex = 0; } diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 6c29486a4..7b781f269 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -15,30 +15,28 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); LUAU_FASTFLAG(LuauImmutableTypes) -LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauFollowWithCommittingTxnLogInAnyUnification, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) -LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, true) +LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) namespace Luau { struct PromoteTypeLevels { - DEPRECATED_TxnLog& DEPRECATED_log; TxnLog& log; const TypeArena* typeArena = nullptr; TypeLevel minLevel; - explicit PromoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) - : DEPRECATED_log(DEPRECATED_log) - , log(log) + PromoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel) + : log(log) , typeArena(typeArena) , minLevel(minLevel) { @@ -50,15 +48,7 @@ struct PromoteTypeLevels LUAU_ASSERT(t); if (minLevel.subsumesStrict(t->level)) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeLevel(ty, minLevel); - } - else - { - DEPRECATED_log(ty); - t->level = minLevel; - } + log.changeLevel(ty, minLevel); } } @@ -81,10 +71,10 @@ struct PromoteTypeLevels { // Surprise, it's actually a BoundTypeVar that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauUseCommittingTxnLog && !log.is(ty)) + if (!log.is(ty)) return true; - promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } @@ -94,7 +84,7 @@ struct PromoteTypeLevels if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) return false; - promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } @@ -107,7 +97,7 @@ struct PromoteTypeLevels if (ttv.state != TableState::Free && ttv.state != TableState::Generic) return true; - promote(ty, FFlag::LuauUseCommittingTxnLog ? log.getMutable(ty) : getMutable(ty)); + promote(ty, log.getMutable(ty)); return true; } @@ -115,33 +105,33 @@ struct PromoteTypeLevels { // Surprise, it's actually a BoundTypePack that hasn't been committed yet. // Calling getMutable on this will trigger an assertion. - if (FFlag::LuauUseCommittingTxnLog && !log.is(tp)) + if (!log.is(tp)) return true; - promote(tp, FFlag::LuauUseCommittingTxnLog ? log.getMutable(tp) : getMutable(tp)); + promote(tp, log.getMutable(tp)); return true; } }; -static void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) +static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) return; - PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; + PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(ty, ptl, seen); } // TODO: use this and make it static. -void promoteTypeLevels(DEPRECATED_TxnLog& DEPRECATED_log, TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) +void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) return; - PromoteTypeLevels ptl{DEPRECATED_log, log, typeArena, minLevel}; + PromoteTypeLevels ptl{log, typeArena, minLevel}; DenseHashSet seen{nullptr}; visitTypeVarOnce(tp, ptl, seen); } @@ -251,7 +241,7 @@ struct SkipCacheForType bool Widen::isDirty(TypeId ty) { - return FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty)); + return log->is(ty); } bool Widen::isDirty(TypePackId) @@ -262,7 +252,7 @@ bool Widen::isDirty(TypePackId) TypeId Widen::clean(TypeId ty) { LUAU_ASSERT(isDirty(ty)); - auto stv = FFlag::LuauUseCommittingTxnLog ? log->getMutable(ty) : getMutable(ty); + auto stv = log->getMutable(ty); LUAU_ASSERT(stv); if (get(stv)) @@ -284,11 +274,11 @@ bool Widen::ignoreChildren(TypeId ty) { // Sometimes we unify ("hi") -> free1 with (free2) -> free3, so don't ignore functions. // TODO: should we be doing this? we would need to rework how checkCallOverload does the unification. - if (FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))) + if (log->is(ty)) return false; // We only care about unions. - return !(FFlag::LuauUseCommittingTxnLog ? log->is(ty) : bool(get(ty))); + return !log->is(ty); } static std::optional hasUnificationTooComplex(const ErrorVec& errors) @@ -335,7 +325,6 @@ Unifier::Unifier(TypeArena* types, Mode mode, std::vector(superTy); - auto subFree = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superFree = log.getMutable(superTy); - subFree = log.getMutable(subTy); - } + auto superFree = log.getMutable(superTy); + auto subFree = log.getMutable(subTy); if (superFree && subFree && superFree->level.subsumes(subFree->level)) { occursCheck(subTy, superTy); // The occurrence check might have caused superTy no longer to be a free type - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(subTy)); - else - occursFailed = bool(get(subTy)); + bool occursFailed = bool(log.getMutable(subTy)); if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.replace(subTy, BoundTypeVar(superTy)); - } - else - { - DEPRECATED_log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); - } + log.replace(subTy, BoundTypeVar(superTy)); } return; } else if (superFree && subFree) { - if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) - { - DEPRECATED_log(superTy); - subFree->level = min(subFree->level, superFree->level); - } - occursCheck(superTy, subTy); - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(superTy)); - else - occursFailed = bool(get(superTy)); - - if (!FFlag::LuauErrorRecoveryType && !FFlag::LuauUseCommittingTxnLog) - { - *asMutable(superTy) = BoundTypeVar(subTy); - return; - } + bool occursFailed = bool(log.getMutable(superTy)); if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) - { - if (superFree->level.subsumes(subFree->level)) - { - log.changeLevel(subTy, superFree->level); - } - - log.replace(superTy, BoundTypeVar(subTy)); - } - else + if (superFree->level.subsumes(subFree->level)) { - DEPRECATED_log(superTy); - *asMutable(superTy) = BoundTypeVar(subTy); - subFree->level = min(subFree->level, superFree->level); + log.changeLevel(subTy, superFree->level); } + + log.replace(superTy, BoundTypeVar(subTy)); } return; @@ -460,14 +398,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool TypeLevel superLevel = superFree->level; occursCheck(superTy, subTy); - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(superTy)); - else - occursFailed = bool(get(superTy)); + bool occursFailed = bool(log.getMutable(superTy)); // Unification can't change the level of a generic. - auto subGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy); + auto subGeneric = log.getMutable(subTy); if (subGeneric && !subGeneric->level.subsumes(superLevel)) { // TODO: a more informative error message? CLI-39912 @@ -478,18 +412,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // The occurrence check might have caused superTy no longer to be a free type if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) - { - promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - log.replace(superTy, BoundTypeVar(widen(subTy))); - } - else - { - promoteTypeLevels(DEPRECATED_log, log, types, superLevel, subTy); - - DEPRECATED_log(superTy); - *asMutable(superTy) = BoundTypeVar(widen(subTy)); - } + promoteTypeLevels(log, types, superLevel, subTy); + log.replace(superTy, BoundTypeVar(widen(subTy))); } return; @@ -499,14 +423,10 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool TypeLevel subLevel = subFree->level; occursCheck(subTy, superTy); - bool occursFailed = false; - if (FFlag::LuauUseCommittingTxnLog) - occursFailed = bool(log.getMutable(subTy)); - else - occursFailed = bool(get(subTy)); + bool occursFailed = bool(log.getMutable(subTy)); // Unification can't change the level of a generic. - auto superGeneric = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy); + auto superGeneric = log.getMutable(superTy); if (superGeneric && !superGeneric->level.subsumes(subFree->level)) { // TODO: a more informative error message? CLI-39912 @@ -516,18 +436,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (!occursFailed) { - if (FFlag::LuauUseCommittingTxnLog) - { - promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); - log.replace(subTy, BoundTypeVar(superTy)); - } - else - { - promoteTypeLevels(DEPRECATED_log, log, types, subLevel, superTy); - - DEPRECATED_log(subTy); - *asMutable(subTy) = BoundTypeVar(superTy); - } + promoteTypeLevels(log, types, subLevel, superTy); + log.replace(subTy, BoundTypeVar(superTy)); } return; @@ -550,55 +460,38 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool // Here, we assume that the types unify. If they do not, we will find out as we roll back // the stack. - if (FFlag::LuauUseCommittingTxnLog) - { - if (log.haveSeen(superTy, subTy)) - return; - - log.pushSeen(superTy, subTy); - } - else - { - if (DEPRECATED_log.haveSeen(superTy, subTy)) - return; + if (log.haveSeen(superTy, subTy)) + return; - DEPRECATED_log.pushSeen(superTy, subTy); - } + log.pushSeen(superTy, subTy); - if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) + if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); } - else if (const UnionTypeVar* uv = FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + else if (const UnionTypeVar* uv = log.getMutable(superTy)) { tryUnifyTypeWithUnion(subTy, superTy, uv, cacheEnabled, isFunctionCall); } - else if (const IntersectionTypeVar* uv = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) + else if (const IntersectionTypeVar* uv = log.getMutable(superTy)) { tryUnifyTypeWithIntersection(subTy, superTy, uv); } - else if (const IntersectionTypeVar* uv = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) + else if (const IntersectionTypeVar* uv = log.getMutable(subTy)) { tryUnifyIntersectionWithType(subTy, uv, superTy, cacheEnabled, isFunctionCall); } - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); - else if (FFlag::LuauSingletonTypes && - ((FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy)) || - (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) && - (FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy))) + else if (FFlag::LuauSingletonTypes && (log.getMutable(superTy) || log.getMutable(superTy)) && + log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyFunctions(subTy, superTy, isFunctionCall); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy) && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy) && get(subTy))) + else if (log.getMutable(superTy) && log.getMutable(subTy)) { tryUnifyTables(subTy, superTy, isIntersection); @@ -607,29 +500,23 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + else if (log.getMutable(superTy)) tryUnifyWithMetatable(subTy, superTy, /*reversed*/ false); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + else if (log.getMutable(subTy)) tryUnifyWithMetatable(superTy, subTy, /*reversed*/ true); - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(superTy)) || - (!FFlag::LuauUseCommittingTxnLog && get(superTy))) + else if (log.getMutable(superTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ false); // Unification of nonclasses with classes is almost, but not quite symmetrical. // The order in which we perform this test is significant in the case that both types are classes. - else if ((FFlag::LuauUseCommittingTxnLog && log.getMutable(subTy)) || (!FFlag::LuauUseCommittingTxnLog && get(subTy))) + else if (log.getMutable(subTy)) tryUnifyWithClass(subTy, superTy, /*reversed*/ true); else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); - if (FFlag::LuauUseCommittingTxnLog) - log.popSeen(superTy, subTy); - else - DEPRECATED_log.popSeen(superTy, subTy); + log.popSeen(superTy, subTy); } void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId superTy) @@ -660,28 +547,12 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) { - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); } else { - if (FFlag::LuauUseCommittingTxnLog) - { - if (i == count - 1) - { - log.concat(std::move(innerState.log)); - } - } - else + if (i == count - 1) { - if (i != count - 1) - { - innerState.DEPRECATED_log.rollback(); - } - else - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } + log.concat(std::move(innerState.log)); } ++i; @@ -692,7 +563,7 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId if (FFlag::LuauDifferentOrderOfUnificationDoesntMatter) { auto tryBind = [this, subTy](TypeId superOption) { - superOption = FFlag::LuauUseCommittingTxnLog ? log.follow(superOption) : follow(superOption); + superOption = log.follow(superOption); // just skip if the superOption is not free-ish. auto ttv = log.getMutable(superOption); @@ -701,36 +572,17 @@ void Unifier::tryUnifyUnionWithType(TypeId subTy, const UnionTypeVar* uv, TypeId // Since we have already checked if S <: T, checking it again will not queue up the type for replacement. // So we'll have to do it ourselves. We assume they unified cleanly if they are still in the seen set. - if (FFlag::LuauUseCommittingTxnLog) - { - if (log.haveSeen(subTy, superOption)) - { - // TODO: would it be nice for TxnLog::replace to do this? - if (log.is(superOption)) - log.bindTable(superOption, subTy); - else - log.replace(superOption, *subTy); - } - } - else + if (log.haveSeen(subTy, superOption)) { - if (DEPRECATED_log.haveSeen(subTy, superOption)) - { - if (auto ttv = getMutable(superOption)) - { - DEPRECATED_log(ttv); - ttv->boundTo = subTy; - } - else - { - DEPRECATED_log(superOption); - *asMutable(superOption) = BoundTypeVar(subTy); - } - } + // TODO: would it be nice for TxnLog::replace to do this? + if (log.is(superOption)) + log.bindTable(superOption, subTy); + else + log.replace(superOption, *subTy); } }; - if (auto utv = (FFlag::LuauUseCommittingTxnLog ? log.getMutable(superTy) : get(superTy))) + if (auto utv = log.getMutable(superTy)) { for (TypeId ty : utv) tryBind(ty); @@ -815,10 +667,7 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (innerState.errors.empty()) { found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); break; } @@ -833,9 +682,6 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp if (!failedOption) failedOption = {innerState.errors.front()}; } - - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -870,10 +716,7 @@ void Unifier::tryUnifyTypeWithIntersection(TypeId subTy, TypeId superTy, const I firstFailedOption = {innerState.errors.front()}; } - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); } if (unificationTooComplex) @@ -915,19 +758,13 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV if (innerState.errors.empty()) { found = true; - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); break; } else if (auto e = hasUnificationTooComplex(innerState.errors)) { unificationTooComplex = e; } - - if (!FFlag::LuauUseCommittingTxnLog) - innerState.DEPRECATED_log.rollback(); } if (unificationTooComplex) @@ -971,78 +808,6 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy) sharedState.cachedUnify.insert({subTy, superTy}); } -struct DEPRECATED_WeirdIter -{ - TypePackId packId; - const TypePack* pack; - size_t index; - bool growing; - TypeLevel level; - - DEPRECATED_WeirdIter(TypePackId packId) - : packId(packId) - , pack(get(packId)) - , index(0) - , growing(false) - { - while (pack && pack->head.empty() && pack->tail) - { - packId = *pack->tail; - pack = get(packId); - } - } - - DEPRECATED_WeirdIter(const DEPRECATED_WeirdIter&) = default; - - const TypeId& operator*() - { - LUAU_ASSERT(good()); - return pack->head[index]; - } - - bool good() const - { - return pack != nullptr && index < pack->head.size(); - } - - bool advance() - { - if (!pack) - return good(); - - if (index < pack->head.size()) - ++index; - - if (growing || index < pack->head.size()) - return good(); - - if (pack->tail) - { - packId = follow(*pack->tail); - pack = get(packId); - index = 0; - } - - return good(); - } - - bool canGrow() const - { - return nullptr != get(packId); - } - - void grow(TypePackId newTail) - { - LUAU_ASSERT(canGrow()); - level = get(packId)->level; - *asMutable(packId) = Unifiable::Bound(newTail); - packId = newTail; - pack = get(newTail); - index = 0; - growing = true; - } -}; - struct WeirdIter { TypePackId packId; @@ -1141,9 +906,6 @@ ErrorVec Unifier::canUnify(TypeId subTy, TypeId superTy) Unifier s = makeChildUnifier(); s.tryUnify_(subTy, superTy); - if (!FFlag::LuauUseCommittingTxnLog) - s.DEPRECATED_log.rollback(); - return s.errors; } @@ -1152,9 +914,6 @@ ErrorVec Unifier::canUnify(TypePackId subTy, TypePackId superTy, bool isFunction Unifier s = makeChildUnifier(); s.tryUnify_(subTy, superTy, isFunctionCall); - if (!FFlag::LuauUseCommittingTxnLog) - s.DEPRECATED_log.rollback(); - return s.errors; } @@ -1200,441 +959,229 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal return; } - if (FFlag::LuauUseCommittingTxnLog) + superTp = log.follow(superTp); + subTp = log.follow(subTp); + + while (auto tp = log.getMutable(subTp)) + { + if (tp->head.empty() && tp->tail) + subTp = log.follow(*tp->tail); + else + break; + } + + while (auto tp = log.getMutable(superTp)) + { + if (tp->head.empty() && tp->tail) + superTp = log.follow(*tp->tail); + else + break; + } + + if (superTp == subTp) + return; + + if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) + return; + + if (log.getMutable(superTp)) { - superTp = log.follow(superTp); - subTp = log.follow(subTp); + occursCheck(superTp, subTp); - while (auto tp = log.getMutable(subTp)) + if (!log.getMutable(superTp)) { - if (tp->head.empty() && tp->tail) - subTp = log.follow(*tp->tail); - else - break; + log.replace(superTp, Unifiable::Bound(subTp)); } + } + else if (log.getMutable(subTp)) + { + occursCheck(subTp, superTp); - while (auto tp = log.getMutable(superTp)) + if (!log.getMutable(subTp)) { - if (tp->head.empty() && tp->tail) - superTp = log.follow(*tp->tail); - else - break; + log.replace(subTp, Unifiable::Bound(superTp)); } + } + else if (log.getMutable(superTp)) + tryUnifyWithAny(subTp, superTp); + else if (log.getMutable(subTp)) + tryUnifyWithAny(superTp, subTp); + else if (log.getMutable(superTp)) + tryUnifyVariadics(subTp, superTp, false); + else if (log.getMutable(subTp)) + tryUnifyVariadics(superTp, subTp, true); + else if (log.getMutable(superTp) && log.getMutable(subTp)) + { + auto superTpv = log.getMutable(superTp); + auto subTpv = log.getMutable(subTp); - if (superTp == subTp) - return; + // If the size of two heads does not match, but both packs have free tail + // We set the sentinel variable to say so to avoid growing it forever. + auto [superTypes, superTail] = logAwareFlatten(superTp, log); + auto [subTypes, subTail] = logAwareFlatten(subTp, log); - if (FFlag::LuauTxnLogSeesTypePacks2 && log.haveSeen(superTp, subTp)) - return; + bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && + (subTail && log.getMutable(*subTail)); + + auto superIter = WeirdIter(superTp, log); + auto subIter = WeirdIter(subTp, log); - if (log.getMutable(superTp)) + auto mkFreshType = [this](TypeLevel level) { + return types->freshType(level); + }; + + const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + + int loopCount = 0; + + do { - occursCheck(superTp, subTp); + if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) + ice("Detected possibly infinite TypePack growth"); - if (!log.getMutable(superTp)) + ++loopCount; + + if (superIter.good() && subIter.growing) { - log.replace(superTp, Unifiable::Bound(subTp)); + subIter.pushType(mkFreshType(subIter.level)); } - } - else if (log.getMutable(subTp)) - { - occursCheck(subTp, superTp); - if (!log.getMutable(subTp)) + if (subIter.good() && superIter.growing) { - log.replace(subTp, Unifiable::Bound(superTp)); + superIter.pushType(mkFreshType(superIter.level)); } - } - else if (log.getMutable(superTp)) - tryUnifyWithAny(subTp, superTp); - else if (log.getMutable(subTp)) - tryUnifyWithAny(superTp, subTp); - else if (log.getMutable(superTp)) - tryUnifyVariadics(subTp, superTp, false); - else if (log.getMutable(subTp)) - tryUnifyVariadics(superTp, subTp, true); - else if (log.getMutable(superTp) && log.getMutable(subTp)) - { - auto superTpv = log.getMutable(superTp); - auto subTpv = log.getMutable(subTp); - - // If the size of two heads does not match, but both packs have free tail - // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = logAwareFlatten(superTp, log); - auto [subTypes, subTail] = logAwareFlatten(subTp, log); - bool noInfiniteGrowth = (superTypes.size() != subTypes.size()) && (superTail && log.getMutable(*superTail)) && - (subTail && log.getMutable(*subTail)); + if (superIter.good() && subIter.good()) + { + tryUnify_(*subIter, *superIter); - auto superIter = WeirdIter(superTp, log); - auto subIter = WeirdIter(subTp, log); + if (!errors.empty() && !firstPackErrorPos) + firstPackErrorPos = loopCount; - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); - }; + superIter.advance(); + subIter.advance(); + continue; + } - const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); + // If both are at the end, we're done + if (!superIter.good() && !subIter.good()) + { + if (subTpv->tail && superTpv->tail) + { + tryUnify_(*subTpv->tail, *superTpv->tail); + break; + } - int loopCount = 0; + const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; + const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; + if (lFreeTail) + tryUnify_(emptyTp, *superTpv->tail); + else if (rFreeTail) + tryUnify_(emptyTp, *subTpv->tail); - do - { - if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) - ice("Detected possibly infinite TypePack growth"); + break; + } - ++loopCount; + // If both tails are free, bind one to the other and call it a day + if (superIter.canGrow() && subIter.canGrow()) + return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); - if (superIter.good() && subIter.growing) + // If just one side is free on its tail, grow it to fit the other side. + // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. + if (superIter.canGrow()) + superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else if (subIter.canGrow()) + subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); + else + { + // A union type including nil marks an optional argument + if (superIter.good() && isOptional(*superIter)) { - subIter.pushType(mkFreshType(subIter.level)); + superIter.advance(); + continue; } - - if (subIter.good() && superIter.growing) + else if (subIter.good() && isOptional(*subIter)) { - superIter.pushType(mkFreshType(superIter.level)); + subIter.advance(); + continue; } - if (superIter.good() && subIter.good()) + // In nonstrict mode, any also marks an optional argument. + else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) { - tryUnify_(*subIter, *superIter); - - if (!errors.empty() && !firstPackErrorPos) - firstPackErrorPos = loopCount; - superIter.advance(); - subIter.advance(); continue; } - // If both are at the end, we're done - if (!superIter.good() && !subIter.good()) + if (log.getMutable(superIter.packId)) + { + tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); + return; + } + + if (log.getMutable(subIter.packId)) { - if (subTpv->tail && superTpv->tail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - break; - } - - const bool lFreeTail = superTpv->tail && log.getMutable(log.follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && log.getMutable(log.follow(*subTpv->tail)) != nullptr; - if (lFreeTail) - tryUnify_(emptyTp, *superTpv->tail); - else if (rFreeTail) - tryUnify_(emptyTp, *subTpv->tail); + tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); + return; + } - break; + if (!isFunctionCall && subIter.good()) + { + // Sometimes it is ok to pass too many arguments + return; } - // If both tails are free, bind one to the other and call it a day - if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); + // This is a bit weird because we don't actually know expected vs actual. We just know + // subtype vs supertype. If we are checking the values returned by a function, we swap + // these to produce the expected error message. + size_t expectedSize = size(superTp); + size_t actualSize = size(subTp); + if (ctx == CountMismatch::Result) + std::swap(expectedSize, actualSize); + reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); - // If just one side is free on its tail, grow it to fit the other side. - // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. - if (superIter.canGrow()) - superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - else if (subIter.canGrow()) - subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - else + while (superIter.good()) { - // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) - { - superIter.advance(); - continue; - } - else if (subIter.good() && isOptional(*subIter)) - { - subIter.advance(); - continue; - } - - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) - { - superIter.advance(); - continue; - } - - if (log.getMutable(superIter.packId)) - { - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - return; - } - - if (log.getMutable(subIter.packId)) - { - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - return; - } - - if (!isFunctionCall && subIter.good()) - { - // Sometimes it is ok to pass too many arguments - return; - } - - // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking the values returned by a function, we swap - // these to produce the expected error message. - size_t expectedSize = size(superTp); - size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) - std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); - - while (superIter.good()) - { - tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); - superIter.advance(); - } - - while (subIter.good()) - { - tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); - subIter.advance(); - } + tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); + superIter.advance(); + } - return; + while (subIter.good()) + { + tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); + subIter.advance(); } - } while (!noInfiniteGrowth); - } - else - { - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); - } + return; + } + + } while (!noInfiniteGrowth); } else { - superTp = follow(superTp); - subTp = follow(subTp); + reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); + } +} - while (auto tp = get(subTp)) - { - if (tp->head.empty() && tp->tail) - subTp = follow(*tp->tail); - else - break; - } +void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) +{ + const PrimitiveTypeVar* superPrim = get(superTy); + const PrimitiveTypeVar* subPrim = get(subTy); + if (!superPrim || !subPrim) + ice("passed non primitive types to unifyPrimitives"); - while (auto tp = get(superTp)) - { - if (tp->head.empty() && tp->tail) - superTp = follow(*tp->tail); - else - break; - } + if (superPrim->type != subPrim->type) + reportError(TypeError{location, TypeMismatch{superTy, subTy}}); +} - if (superTp == subTp) - return; +void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) +{ + const PrimitiveTypeVar* superPrim = get(superTy); + const SingletonTypeVar* superSingleton = get(superTy); + const SingletonTypeVar* subSingleton = get(subTy); - if (FFlag::LuauTxnLogSeesTypePacks2 && DEPRECATED_log.haveSeen(superTp, subTp)) - return; - - if (get(superTp)) - { - occursCheck(superTp, subTp); - - if (!get(superTp)) - { - DEPRECATED_log(superTp); - *asMutable(superTp) = Unifiable::Bound(subTp); - } - } - else if (get(subTp)) - { - occursCheck(subTp, superTp); - - if (!get(subTp)) - { - DEPRECATED_log(subTp); - *asMutable(subTp) = Unifiable::Bound(superTp); - } - } - - else if (get(superTp)) - tryUnifyWithAny(subTp, superTp); - - else if (get(subTp)) - tryUnifyWithAny(superTp, subTp); - - else if (get(superTp)) - tryUnifyVariadics(subTp, superTp, false); - else if (get(subTp)) - tryUnifyVariadics(superTp, subTp, true); - - else if (get(superTp) && get(subTp)) - { - auto superTpv = get(superTp); - auto subTpv = get(subTp); - - // If the size of two heads does not match, but both packs have free tail - // We set the sentinel variable to say so to avoid growing it forever. - auto [superTypes, superTail] = flatten(superTp); - auto [subTypes, subTail] = flatten(subTp); - - bool noInfiniteGrowth = - (superTypes.size() != subTypes.size()) && (superTail && get(*superTail)) && (subTail && get(*subTail)); - - auto superIter = DEPRECATED_WeirdIter{superTp}; - auto subIter = DEPRECATED_WeirdIter{subTp}; - - auto mkFreshType = [this](TypeLevel level) { - return types->freshType(level); - }; - - const TypePackId emptyTp = types->addTypePack(TypePack{{}, std::nullopt}); - - int loopCount = 0; - - do - { - if (FInt::LuauTypeInferTypePackLoopLimit > 0 && loopCount >= FInt::LuauTypeInferTypePackLoopLimit) - ice("Detected possibly infinite TypePack growth"); - - ++loopCount; - - if (superIter.good() && subIter.growing) - asMutable(subIter.pack)->head.push_back(mkFreshType(subIter.level)); - - if (subIter.good() && superIter.growing) - asMutable(superIter.pack)->head.push_back(mkFreshType(superIter.level)); - - if (superIter.good() && subIter.good()) - { - tryUnify_(*subIter, *superIter); - - if (!errors.empty() && !firstPackErrorPos) - firstPackErrorPos = loopCount; - - superIter.advance(); - subIter.advance(); - continue; - } - - // If both are at the end, we're done - if (!superIter.good() && !subIter.good()) - { - if (subTpv->tail && superTpv->tail) - { - tryUnify_(*subTpv->tail, *superTpv->tail); - break; - } - - const bool lFreeTail = superTpv->tail && get(follow(*superTpv->tail)) != nullptr; - const bool rFreeTail = subTpv->tail && get(follow(*subTpv->tail)) != nullptr; - if (lFreeTail) - tryUnify_(emptyTp, *superTpv->tail); - else if (rFreeTail) - tryUnify_(emptyTp, *subTpv->tail); - - break; - } - - // If both tails are free, bind one to the other and call it a day - if (superIter.canGrow() && subIter.canGrow()) - return tryUnify_(*subIter.pack->tail, *superIter.pack->tail); - - // If just one side is free on its tail, grow it to fit the other side. - // FIXME: The tail-most tail of the growing pack should be the same as the tail-most tail of the non-growing pack. - if (superIter.canGrow()) - superIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - - else if (subIter.canGrow()) - subIter.grow(types->addTypePack(TypePackVar(TypePack{}))); - - else - { - // A union type including nil marks an optional argument - if (superIter.good() && isOptional(*superIter)) - { - superIter.advance(); - continue; - } - else if (subIter.good() && isOptional(*subIter)) - { - subIter.advance(); - continue; - } - - // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && get(follow(*superIter))) - { - superIter.advance(); - continue; - } - - if (get(superIter.packId)) - { - tryUnifyVariadics(subIter.packId, superIter.packId, false, int(subIter.index)); - return; - } - - if (get(subIter.packId)) - { - tryUnifyVariadics(superIter.packId, subIter.packId, true, int(superIter.index)); - return; - } - - if (!isFunctionCall && subIter.good()) - { - // Sometimes it is ok to pass too many arguments - return; - } - - // This is a bit weird because we don't actually know expected vs actual. We just know - // subtype vs supertype. If we are checking the values returned by a function, we swap - // these to produce the expected error message. - size_t expectedSize = size(superTp); - size_t actualSize = size(subTp); - if (ctx == CountMismatch::Result) - std::swap(expectedSize, actualSize); - reportError(TypeError{location, CountMismatch{expectedSize, actualSize, ctx}}); - - while (superIter.good()) - { - tryUnify_(*superIter, getSingletonTypes().errorRecoveryType()); - superIter.advance(); - } - - while (subIter.good()) - { - tryUnify_(*subIter, getSingletonTypes().errorRecoveryType()); - subIter.advance(); - } - - return; - } - - } while (!noInfiniteGrowth); - } - else - { - reportError(TypeError{location, GenericError{"Failed to unify type packs"}}); - } - } -} - -void Unifier::tryUnifyPrimitives(TypeId subTy, TypeId superTy) -{ - const PrimitiveTypeVar* superPrim = get(superTy); - const PrimitiveTypeVar* subPrim = get(subTy); - if (!superPrim || !subPrim) - ice("passed non primitive types to unifyPrimitives"); - - if (superPrim->type != subPrim->type) - reportError(TypeError{location, TypeMismatch{superTy, subTy}}); -} - -void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) -{ - const PrimitiveTypeVar* superPrim = get(superTy); - const SingletonTypeVar* superSingleton = get(superTy); - const SingletonTypeVar* subSingleton = get(subTy); - - if ((!superPrim && !superSingleton) || !subSingleton) - ice("passed non singleton/primitive types to unifySingletons"); + if ((!superPrim && !superSingleton) || !subSingleton) + ice("passed non singleton/primitive types to unifySingletons"); if (superSingleton && *superSingleton == *subSingleton) return; @@ -1650,14 +1197,8 @@ void Unifier::tryUnifySingletons(TypeId subTy, TypeId superTy) void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCall) { - FunctionTypeVar* superFunction = getMutable(superTy); - FunctionTypeVar* subFunction = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superFunction = log.getMutable(superTy); - subFunction = log.getMutable(subTy); - } + FunctionTypeVar* superFunction = log.getMutable(superTy); + FunctionTypeVar* subFunction = log.getMutable(subTy); if (!superFunction || !subFunction) ice("passed non-function types to unifyFunction"); @@ -1680,20 +1221,14 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal for (size_t i = 0; i < numGenerics; i++) { - if (FFlag::LuauUseCommittingTxnLog) - log.pushSeen(superFunction->generics[i], subFunction->generics[i]); - else - DEPRECATED_log.pushSeen(superFunction->generics[i], subFunction->generics[i]); + log.pushSeen(superFunction->generics[i], subFunction->generics[i]); } if (FFlag::LuauTxnLogSeesTypePacks2) { for (size_t i = 0; i < numGenericPacks; i++) { - if (FFlag::LuauUseCommittingTxnLog) - log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - else - DEPRECATED_log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + log.pushSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } } @@ -1734,14 +1269,7 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal reportError(TypeError{location, TypeMismatch{superTy, subTy, "", innerState.errors.front()}}); } - if (FFlag::LuauUseCommittingTxnLog) - { - log.concat(std::move(innerState.log)); - } - else - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } + log.concat(std::move(innerState.log)); } else { @@ -1754,33 +1282,19 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal if (!FFlag::LuauImmutableTypes) { - if (FFlag::LuauUseCommittingTxnLog) + if (superFunction->definition && !subFunction->definition && !subTy->persistent) { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - PendingType* newSubTy = log.queue(subTy); - FunctionTypeVar* newSubFtv = getMutable(newSubTy); - LUAU_ASSERT(newSubFtv); - newSubFtv->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - PendingType* newSuperTy = log.queue(superTy); - FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); - LUAU_ASSERT(newSuperFtv); - newSuperFtv->definition = subFunction->definition; - } + PendingType* newSubTy = log.queue(subTy); + FunctionTypeVar* newSubFtv = getMutable(newSubTy); + LUAU_ASSERT(newSubFtv); + newSubFtv->definition = superFunction->definition; } - else + else if (!superFunction->definition && subFunction->definition && !superTy->persistent) { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - subFunction->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - superFunction->definition = subFunction->definition; - } + PendingType* newSuperTy = log.queue(superTy); + FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); + LUAU_ASSERT(newSuperFtv); + newSuperFtv->definition = subFunction->definition; } } @@ -1790,19 +1304,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal { for (int i = int(numGenericPacks) - 1; 0 <= i; i--) { - if (FFlag::LuauUseCommittingTxnLog) - log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); - else - DEPRECATED_log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); + log.popSeen(superFunction->genericPacks[i], subFunction->genericPacks[i]); } } for (int i = int(numGenerics) - 1; 0 <= i; i--) { - if (FFlag::LuauUseCommittingTxnLog) - log.popSeen(superFunction->generics[i], subFunction->generics[i]); - else - DEPRECATED_log.popSeen(superFunction->generics[i], subFunction->generics[i]); + log.popSeen(superFunction->generics[i], subFunction->generics[i]); } } @@ -1833,14 +1341,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (!FFlag::LuauTableSubtypingVariance2) return DEPRECATED_tryUnifyTables(subTy, superTy, isIntersection); - TableTypeVar* superTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* superTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -1855,8 +1357,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - bool isAny = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(superProp.type)) : get(follow(superProp.type)); + bool isAny = log.getMutable(log.follow(superProp.type)); if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) missingProperties.push_back(propName); @@ -1877,8 +1378,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - bool isAny = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(log.follow(subProp.type)) : get(follow(subProp.type)); + bool isAny = log.is(log.follow(subProp.type)); if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) extraProperties.push_back(propName); } @@ -1906,18 +1406,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if (subTable->indexer && maybeString(subTable->indexer->indexType)) { @@ -1931,18 +1421,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` @@ -1953,22 +1433,30 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else if (subTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSub = log.queue(subTy); - TableTypeVar* ttv = getMutable(pendingSub); - LUAU_ASSERT(ttv); - ttv->props[name] = prop; - subTable = ttv; - } - else - { - DEPRECATED_log(subTy); - subTable->props[name] = prop; - } + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* ttv = getMutable(pendingSub); + LUAU_ASSERT(ttv); + ttv->props[name] = prop; + subTable = ttv; } else missingProperties.push_back(name); + + if (FFlag::LuauTxnLogCheckForInvalidation) + { + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } + } } for (const auto& [name, prop] : subTable->props) @@ -1990,18 +1478,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) checkChildUnifierTypeMismatch(innerState.errors, name, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if (superTable->state == TableState::Unsealed) { @@ -2011,18 +1489,10 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) Property clone = prop; clone.type = deeplyOptional(clone.type); - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSuper = log.queue(superTy); - TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); - pendingSuperTtv->props[name] = clone; - superTable = pendingSuperTtv; - } - else - { - DEPRECATED_log(superTy); - superTable->props[name] = clone; - } + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = clone; + superTable = pendingSuperTtv; } else if (variance == Covariant) { @@ -2032,21 +1502,29 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) } else if (superTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSuper = log.queue(superTy); - TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); - pendingSuperTtv->props[name] = prop; - superTable = pendingSuperTtv; - } - else - { - DEPRECATED_log(superTy); - superTable->props[name] = prop; - } + PendingType* pendingSuper = log.queue(superTy); + TableTypeVar* pendingSuperTtv = getMutable(pendingSuper); + pendingSuperTtv->props[name] = prop; + superTable = pendingSuperTtv; } else extraProperties.push_back(name); + + if (FFlag::LuauTxnLogCheckForInvalidation) + { + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); + else + return; + } + } } // Unify indexers @@ -2060,18 +1538,8 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - { - if (innerState.errors.empty()) - log.concat(std::move(innerState.log)); - } - else - { - if (innerState.errors.empty()) - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - else - innerState.DEPRECATED_log.rollback(); - } + if (innerState.errors.empty()) + log.concat(std::move(innerState.log)); } else if (superTable->indexer) { @@ -2081,15 +1549,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. // TODO: we only need to do this if the supertype's indexer is read/write // since that can add indexed elements. - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(subTy, superTable->indexer); - } - else - { - DEPRECATED_log(subTy); - subTable->indexer = superTable->indexer; - } + log.changeIndexer(subTy, superTable->indexer); } } else if (subTable->indexer && variance == Invariant) @@ -2097,15 +1557,29 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) // Symmetric if we are invariant if (superTable->state == TableState::Unsealed || superTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(superTy, subTable->indexer); - } + log.changeIndexer(superTy, subTable->indexer); + } + } + + if (FFlag::LuauTxnLogDontRetryForIndexers) + { + // Changing the indexer can invalidate the table pointers. + superTable = log.getMutable(superTy); + subTable = log.getMutable(subTy); + } + else if (FFlag::LuauTxnLogCheckForInvalidation) + { + // Recursive unification can change the txn log, and invalidate the old + // table. If we detect that this has happened, we start over, with the updated + // txn log. + TableTypeVar* newSuperTable = log.getMutable(superTy); + TableTypeVar* newSubTable = log.getMutable(subTy); + if (superTable != newSuperTable || subTable != newSubTable) + { + if (errors.empty()) + return tryUnifyTables(subTy, superTy, isIntersection); else - { - DEPRECATED_log(superTy); - superTable->indexer = subTable->indexer; - } + return; } } @@ -2134,27 +1608,11 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (superTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(superTy, subTy); - } - else - { - DEPRECATED_log(superTable); - superTable->boundTo = subTy; - } + log.bindTable(superTy, subTy); } else if (subTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(subTy, superTy); - } - else - { - DEPRECATED_log(subTable); - subTable->boundTo = superTy; - } + log.bindTable(subTy, superTy); } } @@ -2197,14 +1655,8 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt Resetter resetter{&variance}; variance = Invariant; - TableTypeVar* superTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* superTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!superTable || !subTable) ice("passed non-table types to unifyTables"); @@ -2231,15 +1683,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt // avoid creating a cycle when the types are already pointing at each other if (follow(superTy) != follow(subTy)) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(superTy, subTy); - } - else - { - DEPRECATED_log(superTable); - superTable->boundTo = subTy; - } + log.bindTable(superTy, subTy); } return; } @@ -2268,14 +1712,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt // e.g. table.insert(t, 1) where t is a non-sealed table and doesn't have an indexer. if (subTable->state == TableState::Unsealed) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(subTy, superTable->indexer); - } - else - { - subTable->indexer = superTable->indexer; - } + log.changeIndexer(subTy, superTable->indexer); } else reportError(TypeError{location, CannotExtendTable{subTy, CannotExtendTable::Indexer}}); @@ -2295,14 +1732,8 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { - TableTypeVar* freeTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - freeTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* freeTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!freeTable || !subTable) ice("passed non-table types to tryUnifyFreeTable"); @@ -2323,22 +1754,11 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) * I believe this is guaranteed to terminate eventually because this will * only happen when a free table is bound to another table. */ - if (FFlag::LuauUseCommittingTxnLog) - { - if (!log.getMutable(superTy) || !log.getMutable(subTy)) - return tryUnify_(subTy, superTy); - - if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) - return tryUnify_(subTy, superTy); - } - else - { - if (!get(superTy) || !get(subTy)) - return tryUnify_(subTy, superTy); + if (!log.getMutable(superTy) || !log.getMutable(subTy)) + return tryUnify_(subTy, superTy); - if (freeTable->boundTo) - return tryUnify_(subTy, superTy); - } + if (TableTypeVar* pendingFreeTtv = log.getMutable(superTy); pendingFreeTtv && pendingFreeTtv->boundTo) + return tryUnify_(subTy, superTy); } else { @@ -2346,17 +1766,10 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) // properties than we previously thought. Else, it is an error. if (subTable->state == TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - PendingType* pendingSub = log.queue(subTy); - TableTypeVar* pendingSubTtv = getMutable(pendingSub); - LUAU_ASSERT(pendingSubTtv); - pendingSubTtv->props.insert({freeName, freeProp}); - } - else - { - subTable->props.insert({freeName, freeProp}); - } + PendingType* pendingSub = log.queue(subTy); + TableTypeVar* pendingSubTtv = getMutable(pendingSub); + LUAU_ASSERT(pendingSubTtv); + pendingSubTtv->props.insert({freeName, freeProp}); } else reportError(TypeError{location, UnknownProperty{subTy, freeName}}); @@ -2370,47 +1783,23 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); } else if (subTable->state == TableState::Free && freeTable->indexer) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.changeIndexer(superTy, subTable->indexer); - } - else - { - freeTable->indexer = subTable->indexer; - } + log.changeIndexer(superTy, subTable->indexer); } if (!freeTable->boundTo && subTable->state != TableState::Free) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(superTy, subTy); - } - else - { - DEPRECATED_log(freeTable); - freeTable->boundTo = subTy; - } + log.bindTable(superTy, subTy); } } void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) { - TableTypeVar* superTable = getMutable(superTy); - TableTypeVar* subTable = getMutable(subTy); - - if (FFlag::LuauUseCommittingTxnLog) - { - superTable = log.getMutable(superTy); - subTable = log.getMutable(subTy); - } + TableTypeVar* superTable = log.getMutable(superTy); + TableTypeVar* subTable = log.getMutable(subTy); if (!superTable || !subTable) ice("passed non-table types to unifySealedTables"); @@ -2476,77 +1865,39 @@ void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersec if (superTable->indexer || subTable->indexer) { - if (FFlag::LuauUseCommittingTxnLog) + if (superTable->indexer && subTable->indexer) + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + else if (subTable->state == TableState::Unsealed) { - if (superTable->indexer && subTable->indexer) - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (subTable->state == TableState::Unsealed) - { - if (superTable->indexer && !subTable->indexer) - { - log.changeIndexer(subTy, superTable->indexer); - } - } - else if (superTable->state == TableState::Unsealed) + if (superTable->indexer && !subTable->indexer) { - if (subTable->indexer && !superTable->indexer) - { - log.changeIndexer(superTy, subTable->indexer); - } - } - else if (superTable->indexer) - { - innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); - for (const auto& [name, type] : subTable->props) - { - const auto& it = superTable->props.find(name); - if (it == superTable->props.end()) - innerState.tryUnify_(type.type, superTable->indexer->indexResultType); - } + log.changeIndexer(subTy, superTable->indexer); } - else - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } - else + else if (superTable->state == TableState::Unsealed) { - if (superTable->indexer && subTable->indexer) - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - else if (subTable->state == TableState::Unsealed) + if (subTable->indexer && !superTable->indexer) { - if (superTable->indexer && !subTable->indexer) - subTable->indexer = superTable->indexer; - } - else if (superTable->state == TableState::Unsealed) - { - if (subTable->indexer && !superTable->indexer) - superTable->indexer = subTable->indexer; + log.changeIndexer(superTy, subTable->indexer); } - else if (superTable->indexer) + } + else if (superTable->indexer) + { + innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); + for (const auto& [name, type] : subTable->props) { - innerState.tryUnify_(getSingletonTypes().stringType, superTable->indexer->indexType); - // We already try to unify properties in both tables. - // Skip those and just look for the ones remaining and see if they fit into the indexer. - for (const auto& [name, type] : subTable->props) - { - const auto& it = superTable->props.find(name); - if (it == superTable->props.end()) - innerState.tryUnify_(type.type, superTable->indexer->indexResultType); - } + const auto& it = superTable->props.find(name); + if (it == superTable->props.end()) + innerState.tryUnify_(type.type, superTable->indexer->indexResultType); } - else - innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } + else + innerState.reportError(TypeError{location, TypeMismatch{superTy, subTy}}); } - if (FFlag::LuauUseCommittingTxnLog) - { - if (!errorReported) - log.concat(std::move(innerState.log)); - } + if (!errorReported) + log.concat(std::move(innerState.log)); else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - - if (errorReported) return; if (!missingPropertiesInSuper.empty()) @@ -2594,8 +1945,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) TypeError mismatchError = TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy}}; - if (const MetatableTypeVar* subMetatable = - FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : get(subTy)) + if (const MetatableTypeVar* subMetatable = log.getMutable(subTy)) { Unifier innerState = makeChildUnifier(); innerState.tryUnify_(subMetatable->table, superMetatable->table); @@ -2606,27 +1956,16 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) else if (!innerState.errors.empty()) reportError(TypeError{location, TypeMismatch{reversed ? subTy : superTy, reversed ? superTy : subTy, "", innerState.errors.front()}}); - if (FFlag::LuauUseCommittingTxnLog) - log.concat(std::move(innerState.log)); - else - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); + log.concat(std::move(innerState.log)); } - else if (TableTypeVar* subTable = FFlag::LuauUseCommittingTxnLog ? log.getMutable(subTy) : getMutable(subTy)) + else if (TableTypeVar* subTable = log.getMutable(subTy)) { switch (subTable->state) { case TableState::Free: { tryUnify_(subTy, superMetatable->table); - - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(subTy, superTy); - } - else - { - subTable->boundTo = superTy; - } + log.bindTable(subTy, superTy); break; } @@ -2637,8 +1976,7 @@ void Unifier::tryUnifyWithMetatable(TypeId subTy, TypeId superTy, bool reversed) reportError(mismatchError); } } - else if (FFlag::LuauUseCommittingTxnLog ? (log.getMutable(subTy) || log.getMutable(subTy)) - : (get(subTy) || get(subTy))) + else if (log.getMutable(subTy) || log.getMutable(subTy)) { } else @@ -2711,28 +2049,13 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) checkChildUnifierTypeMismatch(innerState.errors, propName, reversed ? subTy : superTy, reversed ? superTy : subTy); - if (FFlag::LuauUseCommittingTxnLog) + if (innerState.errors.empty()) { - if (innerState.errors.empty()) - { - log.concat(std::move(innerState.log)); - } - else - { - ok = false; - } + log.concat(std::move(innerState.log)); } else { - if (innerState.errors.empty()) - { - DEPRECATED_log.concat(std::move(innerState.DEPRECATED_log)); - } - else - { - ok = false; - innerState.DEPRECATED_log.rollback(); - } + ok = false; } } } @@ -2747,15 +2070,7 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) if (!ok) return; - if (FFlag::LuauUseCommittingTxnLog) - { - log.bindTable(subTy, superTy); - } - else - { - DEPRECATED_log(subTable); - subTable->boundTo = superTy; - } + log.bindTable(subTy, superTy); } else return fail(); @@ -2771,54 +2086,30 @@ static void queueTypePack(std::vector& queue, DenseHashSet& { while (true) { - a = FFlag::LuauFollowWithCommittingTxnLogInAnyUnification ? state.log.follow(a) : follow(a); + a = state.log.follow(a); if (seenTypePacks.find(a)) break; seenTypePacks.insert(a); - if (FFlag::LuauUseCommittingTxnLog) + if (state.log.getMutable(a)) { - if (state.log.getMutable(a)) - { - state.log.replace(a, Unifiable::Bound{anyTypePack}); - } - else if (auto tp = state.log.getMutable(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + state.log.replace(a, Unifiable::Bound{anyTypePack}); } - else + else if (auto tp = state.log.getMutable(a)) { - if (get(a)) - { - state.DEPRECATED_log(a); - *asMutable(a) = Unifiable::Bound{anyTypePack}; - } - else if (auto tp = get(a)) - { - queue.insert(queue.end(), tp->head.begin(), tp->head.end()); - if (tp->tail) - a = *tp->tail; - else - break; - } + queue.insert(queue.end(), tp->head.begin(), tp->head.end()); + if (tp->tail) + a = *tp->tail; + else + break; } } } void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool reversed, int subOffset) { - const VariadicTypePack* superVariadic = get(superTp); - - if (FFlag::LuauUseCommittingTxnLog) - { - superVariadic = log.getMutable(superTp); - } + const VariadicTypePack* superVariadic = log.getMutable(superTp); if (!superVariadic) ice("passed non-variadic pack to tryUnifyVariadics"); @@ -2843,15 +2134,7 @@ void Unifier::tryUnifyVariadics(TypePackId subTp, TypePackId superTp, bool rever TypePackId tail = follow(*maybeTail); if (get(tail)) { - if (FFlag::LuauUseCommittingTxnLog) - { - log.replace(tail, BoundTypePack(superTp)); - } - else - { - DEPRECATED_log(tail); - *asMutable(tail) = BoundTypePack{superTp}; - } + log.replace(tail, BoundTypePack(superTp)); } else if (const VariadicTypePack* vtp = get(tail)) { @@ -2882,103 +2165,54 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas { while (!queue.empty()) { - if (FFlag::LuauUseCommittingTxnLog) - { - TypeId ty = state.log.follow(queue.back()); - queue.pop_back(); + TypeId ty = state.log.follow(queue.back()); + queue.pop_back(); - // Types from other modules don't have free types - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) - continue; - - if (seen.find(ty)) - continue; + // Types from other modules don't have free types + if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + continue; - seen.insert(ty); + if (seen.find(ty)) + continue; - if (state.log.getMutable(ty)) - { - state.log.replace(ty, BoundTypeVar{anyType}); - } - else if (auto fun = state.log.getMutable(ty)) - { - queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = state.log.getMutable(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); + seen.insert(ty); - if (table->indexer) - { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); - } - } - else if (auto mt = state.log.getMutable(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (state.log.getMutable(ty)) - { - // ClassTypeVars never contain free typevars. - } - else if (auto union_ = state.log.getMutable(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = state.log.getMutable(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); - else - { - } // Primitives, any, errors, and generics are left untouched. + if (state.log.getMutable(ty)) + { + state.log.replace(ty, BoundTypeVar{anyType}); } - else + else if (auto fun = state.log.getMutable(ty)) { - TypeId ty = follow(queue.back()); - queue.pop_back(); - if (seen.find(ty)) - continue; - seen.insert(ty); - - if (get(ty)) - { - state.DEPRECATED_log(ty); - *asMutable(ty) = BoundTypeVar{anyType}; - } - else if (auto fun = get(ty)) - { - queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); - queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); - } - else if (auto table = get(ty)) - { - for (const auto& [_name, prop] : table->props) - queue.push_back(prop.type); + queueTypePack(queue, seenTypePacks, state, fun->argTypes, anyTypePack); + queueTypePack(queue, seenTypePacks, state, fun->retType, anyTypePack); + } + else if (auto table = state.log.getMutable(ty)) + { + for (const auto& [_name, prop] : table->props) + queue.push_back(prop.type); - if (table->indexer) - { - queue.push_back(table->indexer->indexType); - queue.push_back(table->indexer->indexResultType); - } - } - else if (auto mt = get(ty)) - { - queue.push_back(mt->table); - queue.push_back(mt->metatable); - } - else if (get(ty)) + if (table->indexer) { - // ClassTypeVars never contain free typevars. + queue.push_back(table->indexer->indexType); + queue.push_back(table->indexer->indexResultType); } - else if (auto union_ = get(ty)) - queue.insert(queue.end(), union_->options.begin(), union_->options.end()); - else if (auto intersection = get(ty)) - queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); - else - { - } // Primitives, any, errors, and generics are left untouched. } + else if (auto mt = state.log.getMutable(ty)) + { + queue.push_back(mt->table); + queue.push_back(mt->metatable); + } + else if (state.log.getMutable(ty)) + { + // ClassTypeVars never contain free typevars. + } + else if (auto union_ = state.log.getMutable(ty)) + queue.insert(queue.end(), union_->options.begin(), union_->options.end()); + else if (auto intersection = state.log.getMutable(ty)) + queue.insert(queue.end(), intersection->parts.begin(), intersection->parts.end()); + else + { + } // Primitives, any, errors, and generics are left untouched. } } @@ -3038,79 +2272,39 @@ void Unifier::occursCheck(DenseHashSet& seen, TypeId needle, TypeId hays occursCheck(seen, needle, tv); }; - if (FFlag::LuauUseCommittingTxnLog) - { - needle = log.follow(needle); - haystack = log.follow(haystack); + needle = log.follow(needle); + haystack = log.follow(haystack); - if (seen.find(haystack)) - return; + if (seen.find(haystack)) + return; - seen.insert(haystack); + seen.insert(haystack); - if (log.getMutable(needle)) - return; + if (log.getMutable(needle)) + return; - if (!log.getMutable(needle)) - ice("Expected needle to be free"); + if (!log.getMutable(needle)) + ice("Expected needle to be free"); - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - log.replace(needle, *getSingletonTypes().errorRecoveryType()); + if (needle == haystack) + { + reportError(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryType()); - return; - } + return; + } - if (log.getMutable(haystack)) - return; - else if (auto a = log.getMutable(haystack)) - { - for (TypeId ty : a->options) - check(ty); - } - else if (auto a = log.getMutable(haystack)) - { - for (TypeId ty : a->parts) - check(ty); - } + if (log.getMutable(haystack)) + return; + else if (auto a = log.getMutable(haystack)) + { + for (TypeId ty : a->options) + check(ty); } - else + else if (auto a = log.getMutable(haystack)) { - needle = follow(needle); - haystack = follow(haystack); - - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (get(needle)) - return; - - if (!get(needle)) - ice("Expected needle to be free"); - - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - DEPRECATED_log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryType(); - return; - } - - if (get(haystack)) - return; - else if (auto a = get(haystack)) - { - for (TypeId ty : a->options) - check(ty); - } - else if (auto a = get(haystack)) - { - for (TypeId ty : a->parts) - check(ty); - } + for (TypeId ty : a->parts) + check(ty); } } @@ -3123,87 +2317,45 @@ void Unifier::occursCheck(TypePackId needle, TypePackId haystack) void Unifier::occursCheck(DenseHashSet& seen, TypePackId needle, TypePackId haystack) { - if (FFlag::LuauUseCommittingTxnLog) - { - needle = log.follow(needle); - haystack = log.follow(haystack); - - if (seen.find(haystack)) - return; - - seen.insert(haystack); - - if (log.getMutable(needle)) - return; + needle = log.follow(needle); + haystack = log.follow(haystack); - if (!log.getMutable(needle)) - ice("Expected needle pack to be free"); + if (seen.find(haystack)) + return; - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); + seen.insert(haystack); - while (!log.getMutable(haystack)) - { - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); + if (log.getMutable(needle)) + return; - return; - } + if (!log.getMutable(needle)) + ice("Expected needle pack to be free"); - if (auto a = get(haystack); a && a->tail) - { - haystack = log.follow(*a->tail); - continue; - } + RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - break; - } - } - else + while (!log.getMutable(haystack)) { - needle = follow(needle); - haystack = follow(haystack); - - if (seen.find(haystack)) - return; - - seen.insert(haystack); + if (needle == haystack) + { + reportError(TypeError{location, OccursCheckFailed{}}); + log.replace(needle, *getSingletonTypes().errorRecoveryTypePack()); - if (get(needle)) return; + } - if (!get(needle)) - ice("Expected needle pack to be free"); - - RecursionLimiter _ra(&sharedState.counters.recursionCount, FInt::LuauTypeInferRecursionLimit); - - while (!get(haystack)) + if (auto a = get(haystack); a && a->tail) { - if (needle == haystack) - { - reportError(TypeError{location, OccursCheckFailed{}}); - DEPRECATED_log(needle); - *asMutable(needle) = *getSingletonTypes().errorRecoveryTypePack(); - } - - if (auto a = get(haystack); a && a->tail) - { - haystack = follow(*a->tail); - continue; - } - - break; + haystack = log.follow(*a->tail); + continue; } + + break; } } Unifier Unifier::makeChildUnifier() { - if (FFlag::LuauUseCommittingTxnLog) - return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; - else - return Unifier{types, mode, DEPRECATED_log.sharedSeen, location, variance, sharedState, &log}; + return Unifier{types, mode, log.sharedSeen, location, variance, sharedState, &log}; } bool Unifier::isNonstrictMode() const diff --git a/VM/include/lua.h b/VM/include/lua.h index 0a561f274..274c4ed9f 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -229,19 +229,46 @@ LUA_API void lua_setthreaddata(lua_State* L, void* data); enum lua_GCOp { + /* stop and resume incremental garbage collection */ LUA_GCSTOP, LUA_GCRESTART, + + /* run a full GC cycle; not recommended for latency sensitive applications */ LUA_GCCOLLECT, + + /* return the heap size in KB and the remainder in bytes */ LUA_GCCOUNT, LUA_GCCOUNTB, + + /* return 1 if GC is active (not stopped); note that GC may not be actively collecting even if it's running */ LUA_GCISRUNNING, - // garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation - // explicit GC steps allow to perform some amount of work at custom points to offset the need for GC assists - // note that GC might also be paused for some duration (until bytes allocated meet the threshold) - // if an explicit step is performed during this pause, it will trigger the start of the next collection cycle + /* + ** perform an explicit GC step, with the step size specified in KB + ** + ** garbage collection is handled by 'assists' that perform some amount of GC work matching pace of allocation + ** explicit GC steps allow to perform some amount of work at custom points to offset the need for GC assists + ** note that GC might also be paused for some duration (until bytes allocated meet the threshold) + ** if an explicit step is performed during this pause, it will trigger the start of the next collection cycle + */ LUA_GCSTEP, + /* + ** tune GC parameters G (goal), S (step multiplier) and step size (usually best left ignored) + ** + ** garbage collection is incremental and tries to maintain the heap size to balance memory and performance overhead + ** this overhead is determined by G (goal) which is the ratio between total heap size and the amount of live data in it + ** G is specified in percentages; by default G=200% which means that the heap is allowed to grow to ~2x the size of live data. + ** + ** collector tries to collect S% of allocated bytes by interrupting the application after step size bytes were allocated. + ** when S is too small, collector may not be able to catch up and the effective goal that can be reached will be larger. + ** S is specified in percentages; by default S=200% which means that collector will run at ~2x the pace of allocations. + ** + ** it is recommended to set S in the interval [100 / (G - 100), 100 + 100 / (G - 100))] with a minimum value of 150%; for example: + ** - for G=200%, S should be in the interval [150%, 200%] + ** - for G=150%, S should be in the interval [200%, 300%] + ** - for G=125%, S should be in the interval [400%, 500%] + */ LUA_GCSETGOAL, LUA_GCSETSTEPMUL, LUA_GCSETSTEPSIZE, diff --git a/VM/include/luaconf.h b/VM/include/luaconf.h index c5bf1c184..b93cbf7c7 100644 --- a/VM/include/luaconf.h +++ b/VM/include/luaconf.h @@ -59,33 +59,6 @@ #define LUA_IDSIZE 256 #endif -/* -@@ LUAI_GCGOAL defines the desired top heap size in relation to the live heap -@* size at the end of the GC cycle -** CHANGE it if you want the GC to run faster or slower (higher values -** mean larger GC pauses which mean slower collection.) You can also change -** this value dynamically. -*/ -#ifndef LUAI_GCGOAL -#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ -#endif - -/* -@@ LUAI_GCSTEPMUL / LUAI_GCSTEPSIZE define the default speed of garbage collection -@* relative to memory allocation. -** Every LUAI_GCSTEPSIZE KB allocated, incremental collector collects LUAI_GCSTEPSIZE -** times LUAI_GCSTEPMUL% bytes. -** CHANGE it if you want to change the granularity of the garbage -** collection. -*/ -#ifndef LUAI_GCSTEPMUL -#define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ -#endif - -#ifndef LUAI_GCSTEPSIZE -#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ -#endif - /* LUA_MINSTACK is the guaranteed number of Lua stack slots available to a C function */ #ifndef LUA_MINSTACK #define LUA_MINSTACK 20 diff --git a/VM/include/lualib.h b/VM/include/lualib.h index baf27b47e..bebd0a0f0 100644 --- a/VM/include/lualib.h +++ b/VM/include/lualib.h @@ -54,6 +54,8 @@ LUALIB_API lua_State* luaL_newstate(void); LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint); +LUALIB_API const char* luaL_typename(lua_State* L, int idx); + /* ** =============================================================== ** some useful macros @@ -66,8 +68,6 @@ LUALIB_API const char* luaL_findtable(lua_State* L, int idx, const char* fname, #define luaL_checkstring(L, n) (luaL_checklstring(L, (n), NULL)) #define luaL_optstring(L, n, d) (luaL_optlstring(L, (n), (d), NULL)) -#define luaL_typename(L, i) lua_typename(L, lua_type(L, (i))) - #define luaL_getmetatable(L, n) (lua_getfield(L, LUA_REGISTRYINDEX, (n))) #define luaL_opt(L, f, n, d) (lua_isnoneornil(L, (n)) ? (d) : f(L, (n))) diff --git a/VM/src/laux.cpp b/VM/src/laux.cpp index 9a6f77938..9fe2ebb6e 100644 --- a/VM/src/laux.cpp +++ b/VM/src/laux.cpp @@ -11,6 +11,8 @@ #include +LUAU_DYNAMIC_FASTFLAGVARIABLE(LuauMorePreciseLuaLTypeName, false) + /* convert a stack index to positive */ #define abs_index(L, i) ((i) > 0 || (i) <= LUA_REGISTRYINDEX ? (i) : lua_gettop(L) + (i) + 1) @@ -333,6 +335,19 @@ const char* luaL_findtable(lua_State* L, int idx, const char* fname, int szhint) return NULL; } +const char* luaL_typename(lua_State* L, int idx) +{ + if (DFFlag::LuauMorePreciseLuaLTypeName) + { + const TValue* obj = luaA_toobject(L, idx); + return luaT_objtypename(L, obj); + } + else + { + return lua_typename(L, lua_type(L, idx)); + } +} + /* ** {====================================================== ** Generic Buffer manipulation diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 988fd315e..2307598e8 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -10,6 +10,8 @@ #include #include +LUAU_DYNAMIC_FASTFLAG(LuauMorePreciseLuaLTypeName) + static void writestring(const char* s, size_t l) { fwrite(s, 1, l, stdout); @@ -186,7 +188,14 @@ static int luaB_gcinfo(lua_State* L) static int luaB_type(lua_State* L) { luaL_checkany(L, 1); - lua_pushstring(L, luaL_typename(L, 1)); + if (DFFlag::LuauMorePreciseLuaLTypeName) + { + lua_pushstring(L, lua_typename(L, lua_type(L, 1))); + } + else + { + lua_pushstring(L, luaL_typename(L, 1)); + } return 1; } diff --git a/VM/src/lgc.h b/VM/src/lgc.h index cbeeebd48..ad8ee78a0 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -6,6 +6,13 @@ #include "lobject.h" #include "lstate.h" +/* +** Default settings for GC tunables (settable via lua_gc) +*/ +#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ +#define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ +#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ + /* ** Possible states of the Garbage Collector */ diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index ce890ba89..55b0618ad 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -14,7 +14,6 @@ LUAU_FASTFLAG(LuauTraceTypesInNonstrictMode2) LUAU_FASTFLAG(LuauSetMetatableDoesNotTimeTravel) -LUAU_FASTFLAG(LuauUseCommittingTxnLog) LUAU_FASTFLAG(LuauTableCloneType) using namespace Luau; @@ -1912,14 +1911,9 @@ local bar: @1= foo CHECK(!ac.entryMap.count("foo")); } -// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("type_correct_function_no_parenthesis") +TEST_CASE_FIXTURE(ACFixture, "type_correct_function_no_parenthesis") { - ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); - ACFixture fix; - - fix.check(R"( + check(R"( local function target(a: (number) -> number) return a(4) end local function bar1(a: number) return -a end local function bar2(a: string) return a .. 'x' end @@ -1927,7 +1921,7 @@ local function bar2(a: string) return a .. 'x' end return target(b@1 )"); - auto ac = fix.autocomplete('1'); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("bar1")); CHECK(ac.entryMap["bar1"].typeCorrect == TypeCorrectKind::Correct); @@ -1937,8 +1931,6 @@ return target(b@1 TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local function bar(a: number) return -a end local abc = b@1 @@ -1952,8 +1944,6 @@ local abc = b@1 TEST_CASE_FIXTURE(ACFixture, "function_result_passed_to_function_has_parentheses") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local function foo() return 1 end local function bar(a: number) return -a end @@ -1978,14 +1968,9 @@ local fp: @1= f CHECK(ac.entryMap.count("({ x: number, y: number }) -> number")); } -// Switch back to TEST_CASE_FIXTURE with regular ACFixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("type_correct_keywords") +TEST_CASE_FIXTURE(ACFixture, "type_correct_keywords") { - ScopedFastFlag sff_LuauUseCommittingTxnLog = ScopedFastFlag("LuauUseCommittingTxnLog", true); - ACFixture fix; - - fix.check(R"( + check(R"( local function a(x: boolean) end local function b(x: number?) end local function c(x: (number) -> string) end @@ -2002,26 +1987,26 @@ local dc = d(f@4) local ec = e(f@5) )"); - auto ac = fix.autocomplete('1'); + auto ac = autocomplete('1'); CHECK(ac.entryMap.count("tru")); CHECK(ac.entryMap["tru"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["true"].typeCorrect == TypeCorrectKind::Correct); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('2'); + ac = autocomplete('2'); CHECK(ac.entryMap.count("ni")); CHECK(ac.entryMap["ni"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["nil"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('3'); + ac = autocomplete('3'); CHECK(ac.entryMap.count("false")); CHECK(ac.entryMap["false"].typeCorrect == TypeCorrectKind::None); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('4'); + ac = autocomplete('4'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); - ac = fix.autocomplete('5'); + ac = autocomplete('5'); CHECK(ac.entryMap["function"].typeCorrect == TypeCorrectKind::Correct); } @@ -2512,23 +2497,21 @@ local t = { CHECK(ac.entryMap.count("second")); } -TEST_CASE("autocomplete_documentation_symbols") +TEST_CASE_FIXTURE(Fixture, "autocomplete_documentation_symbols") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - fix.loadDefinition(R"( + loadDefinition(R"( declare y: { x: number, } )"); - fix.fileResolver.source["Module/A"] = R"( + fileResolver.source["Module/A"] = R"( local a = y. )"; - fix.frontend.check("Module/A"); + frontend.check("Module/A"); - auto ac = autocomplete(fix.frontend, "Module/A", Position{1, 21}, nullCallback); + auto ac = autocomplete(frontend, "Module/A", Position{1, 21}, nullCallback); REQUIRE(ac.entryMap.count("x")); CHECK_EQ(ac.entryMap["x"].documentationSymbol, "@test/global/y.x"); @@ -2646,8 +2629,6 @@ local a: A<(number, s@1> TEST_CASE_FIXTURE(ACFixture, "autocomplete_first_function_arg_expected_type") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local function foo1() return 1 end local function foo2() return "1" end @@ -2720,7 +2701,6 @@ type A = () -> T TEST_CASE_FIXTURE(ACFixture, "autocomplete_oop_implicit_self") { - ScopedFastFlag flag("LuauMissingFollowACMetatables", true); check(R"( --!strict local Class = {} @@ -2764,8 +2744,6 @@ TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") TEST_CASE_FIXTURE(ACFixture, "function_in_assignment_has_parentheses_2") { - ScopedFastFlag luauAutocompleteAvoidMutation("LuauAutocompleteAvoidMutation", true); - check(R"( local bar: ((number) -> number) & (number, number) -> number) local abc = b@1 diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index d584eb2d2..711c0aa1d 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -7,6 +7,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauFixIncorrectLineNumberDuplicateType) + TEST_SUITE_BEGIN("TypeAliases"); TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_type_alias") @@ -241,6 +243,27 @@ TEST_CASE_FIXTURE(Fixture, "export_type_and_type_alias_are_duplicates") CHECK_EQ(dtd->name, "Foo"); } +TEST_CASE_FIXTURE(Fixture, "reported_location_is_correct_when_type_alias_are_duplicates") +{ + CheckResult result = check(R"( + type A = string + type B = number + type C = string + type B = number + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto dtd = get(result.errors[0]); + REQUIRE(dtd); + CHECK_EQ(dtd->name, "B"); + + if (FFlag::LuauFixIncorrectLineNumberDuplicateType) + CHECK_EQ(dtd->previousLocation.begin.line + 1, 3); + else + CHECK_EQ(dtd->previousLocation.begin.line + 1, 1); +} + TEST_CASE_FIXTURE(Fixture, "stringify_optional_parameterized_alias") { CheckResult result = check(R"( diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index bf9907703..8da655b34 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -8,8 +8,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - TEST_SUITE_BEGIN("BuiltinTests"); TEST_CASE_FIXTURE(Fixture, "math_things_are_defined") @@ -443,28 +441,19 @@ TEST_CASE_FIXTURE(Fixture, "os_time_takes_optional_date_table") CHECK_EQ(*typeChecker.numberType, *requireType("n3")); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("thread_is_a_type") +TEST_CASE_FIXTURE(Fixture, "thread_is_a_type") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( local co = coroutine.create(function() end) )"); - // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() == 0); - CHECK_EQ(*fix.typeChecker.threadType, *fix.requireType("co")); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(*typeChecker.threadType, *requireType("co")); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("coroutine_resume_anything_goes") +TEST_CASE_FIXTURE(Fixture, "coroutine_resume_anything_goes") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( local function nifty(x, y) print(x, y) local z = coroutine.yield(1, 2) @@ -477,17 +466,12 @@ TEST_CASE("coroutine_resume_anything_goes") local answer = coroutine.resume(co, 3) )"); - // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() == 0); + LUAU_REQUIRE_NO_ERRORS(result); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("coroutine_wrap_anything_goes") +TEST_CASE_FIXTURE(Fixture, "coroutine_wrap_anything_goes") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( --!nonstrict local function nifty(x, y) print(x, y) @@ -501,8 +485,7 @@ TEST_CASE("coroutine_wrap_anything_goes") local answer = f(3) )"); - // Replace with LUAU_REQUIRE_NO_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() == 0); + LUAU_REQUIRE_NO_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "setmetatable_should_not_mutate_persisted_types") @@ -961,4 +944,18 @@ TEST_CASE_FIXTURE(Fixture, "table_freeze_is_generic") CHECK_EQ("*unknown*", toString(requireType("d"))); } +TEST_CASE_FIXTURE(Fixture, "set_metatable_needs_arguments") +{ + ScopedFastFlag sff{"LuauSetMetaTableArgsCheck", true}; + CheckResult result = check(R"( +local a = {b=setmetatable} +a.b() +a:b() +a:b({}) + )"); + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 0}, {2, 5}}, CountMismatch{2, 0}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{3, 0}, {3, 5}}, CountMismatch{2, 1}})); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index c482847bf..547fbab15 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -8,6 +8,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauFixArgumentCountMismatchAmountWithGenericTypes) + TEST_SUITE_BEGIN("GenericsTests"); TEST_CASE_FIXTURE(Fixture, "check_generic_function") @@ -786,4 +788,47 @@ local TheDispatcher: Dispatcher = { LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_few") +{ + CheckResult result = check(R"( +function test(a: number) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "generic_argument_count_too_many") +{ + CheckResult result = check(R"( +function test2(a: number, b: string) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test2, 1, "", 3) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 3 arguments, but 4 are specified)"); + else + CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 4 are specified)"); +} + + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index da035ba14..a5eba5dfe 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2334,4 +2334,54 @@ TEST_CASE_FIXTURE(Fixture, "pass_a_union_of_tables_to_a_function_that_requires_a REQUIRE_EQ("{| [any]: any, x: number, y: number |} | {| y: number |}", toString(requireType("b"))); } +TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf1") +{ + ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; + + CheckResult result = check(R"( +-- This example produced a UAF at one point, caused by pointers to table types becoming +-- invalidated by child unifiers. (Calling log.concat can cause pointers to become invalid.) +type _Entry = { + a: number, + + middle: (self: _Entry) -> (), + + z: number +} + +export type AnyEntry = _Entry + +local Entry = {} +Entry.__index = Entry + +function Entry:dispose() + self:middle() + forgetChildren(self) -- unify free with sealed AnyEntry +end + +function forgetChildren(parent: AnyEntry) +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "unifying_tables_shouldnt_uaf2") +{ + ScopedFastFlag sff{"LuauTxnLogCheckForInvalidation", true}; + + CheckResult result = check(R"( +-- Another example that UAFd, this time found by fuzzing. +local _ +do +_._ *= (_[{n0=_[{[{[_]=_,}]=_,}],}])[_] +_ = (_.n0) +end +_._ *= (_[false])[_] +_ = (_.cos) + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index f63579b56..d7bbad20d 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -5144,7 +5144,6 @@ end TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") { - ScopedFastFlag committingTxnLog{"LuauUseCommittingTxnLog", true}; ScopedFastFlag subtypingVariance{"LuauTableSubtypingVariance2", true}; CheckResult result = check(R"( @@ -5355,4 +5354,41 @@ end LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; + + fileResolver.source["game/isAMagicMock"] = R"( +--!nonstrict +return function(value) + return false +end + )"; + + CheckResult result = check(R"( +--!nonstrict +local MagicMock = {} +MagicMock.is = require(game.isAMagicMock) + +function MagicMock.is(value) + return false +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; + + CheckResult result = check(R"( +function string.len(): number + return 1 +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tryUnify.test.cpp b/tests/TypeInfer.tryUnify.test.cpp index f6ee3ccce..d8de25940 100644 --- a/tests/TypeInfer.tryUnify.test.cpp +++ b/tests/TypeInfer.tryUnify.test.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - struct TryUnifyFixture : Fixture { TypeArena arena; @@ -43,8 +41,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "compatible_functions_are_unified") state.tryUnify(&functionTwo, &functionOne); CHECK(state.errors.empty()); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); CHECK_EQ(functionOne, functionTwo); } @@ -86,8 +83,7 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "tables_can_be_unified") CHECK(state.errors.empty()); - if (FFlag::LuauUseCommittingTxnLog) - state.log.commit(); + state.log.commit(); CHECK_EQ(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -110,9 +106,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "incompatible_tables_are_preserved") CHECK_EQ(1, state.errors.size()); - if (!FFlag::LuauUseCommittingTxnLog) - state.DEPRECATED_log.rollback(); - CHECK_NE(*getMutable(&tableOne)->props["foo"].type, *getMutable(&tableTwo)->props["foo"].type); } @@ -217,34 +210,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "cli_41095_concat_log_in_sealed_table_unifica CHECK_EQ(toString(result.errors[1]), "Available overloads: ({a}, a) -> (); and ({a}, number, a) -> ()"); } -TEST_CASE("undo_new_prop_on_unsealed_table") -{ - ScopedFastFlag flags[] = { - {"LuauTableSubtypingVariance2", true}, - // This test makes no sense with a committing TxnLog. - {"LuauUseCommittingTxnLog", false}, - }; - // I am not sure how to make this happen in Luau code. - - TryUnifyFixture fix; - - TypeId unsealedTable = fix.arena.addType(TableTypeVar{TableState::Unsealed, TypeLevel{}}); - TypeId sealedTable = - fix.arena.addType(TableTypeVar{{{"prop", Property{getSingletonTypes().numberType}}}, std::nullopt, TypeLevel{}, TableState::Sealed}); - - const TableTypeVar* ttv = get(unsealedTable); - REQUIRE(ttv); - - fix.state.tryUnify(sealedTable, unsealedTable); - - // To be honest, it's really quite spooky here that we're amending an unsealed table in this case. - CHECK(!ttv->props.empty()); - - fix.state.DEPRECATED_log.rollback(); - - CHECK(ttv->props.empty()); -} - TEST_CASE_FIXTURE(TryUnifyFixture, "free_tail_is_grown_properly") { TypePackId threeNumbers = arena.addTypePack(TypePack{{typeChecker.numberType, typeChecker.numberType, typeChecker.numberType}, std::nullopt}); @@ -267,11 +232,6 @@ TEST_CASE_FIXTURE(TryUnifyFixture, "recursive_metatable_getmatchtag") TEST_CASE_FIXTURE(TryUnifyFixture, "cli_50320_follow_in_any_unification") { - ScopedFastFlag sffs[] = { - {"LuauUseCommittingTxnLog", true}, - {"LuauFollowWithCommittingTxnLogInAnyUnification", true}, - }; - TypePackVar free{FreeTypePack{TypeLevel{}}}; TypePackVar target{TypePack{}}; diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index 6b96f4498..fcc21c18e 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -9,8 +9,6 @@ using namespace Luau; -LUAU_FASTFLAG(LuauUseCommittingTxnLog) - TEST_SUITE_BEGIN("TypePackTests"); TEST_CASE_FIXTURE(Fixture, "infer_multi_return") @@ -264,13 +262,9 @@ TEST_CASE_FIXTURE(Fixture, "variadic_pack_syntax") CHECK_EQ(toString(requireType("foo")), "(...number) -> ()"); } -// Switch back to TEST_CASE_FIXTURE with regular Fixture when removing the -// LuauUseCommittingTxnLog flag. -TEST_CASE("type_pack_hidden_free_tail_infinite_growth") +TEST_CASE_FIXTURE(Fixture, "type_pack_hidden_free_tail_infinite_growth") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( --!nonstrict if _ then _[function(l0)end],l0 = _ @@ -282,8 +276,7 @@ elseif _ then end )"); - // Switch back to LUAU_REQUIRE_ERRORS(result) when using TEST_CASE_FIXTURE. - CHECK(result.errors.size() > 0); + LUAU_REQUIRE_ERRORS(result); } TEST_CASE_FIXTURE(Fixture, "variadic_argument_tail") diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index 0e0b6ebba..ad4cecd87 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -7,7 +7,6 @@ #include "doctest.h" LUAU_FASTFLAG(LuauEqConstraint) -LUAU_FASTFLAG(LuauUseCommittingTxnLog) using namespace Luau; @@ -282,19 +281,16 @@ local c = b:foo(1, 2) CHECK_EQ("Value of type 'A?' could be nil", toString(result.errors[0])); } -TEST_CASE("optional_union_follow") +TEST_CASE_FIXTURE(Fixture, "optional_union_follow") { - Fixture fix(FFlag::LuauUseCommittingTxnLog); - - CheckResult result = fix.check(R"( + CheckResult result = check(R"( local y: number? = 2 local x = y local function f(a: number, b: typeof(x), c: typeof(x)) return -a end return f() )"); - REQUIRE_EQ(result.errors.size(), 1); - // LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); auto acm = get(result.errors[0]); REQUIRE(acm); diff --git a/tools/LuauVisualize.py b/tools/LuauVisualize.py new file mode 100644 index 000000000..40f8d6be0 --- /dev/null +++ b/tools/LuauVisualize.py @@ -0,0 +1,107 @@ +# This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +# HACK: LLDB's python API doesn't afford anything helpful for getting at variadic template parameters. +# We're forced to resort to parsing names as strings. +def templateParams(s): + depth = 0 + start = s.find('<') + 1 + result = [] + for i, c in enumerate(s[start:], start): + if c == '<': + depth += 1 + elif c == '>': + if depth == 0: + result.append(s[start: i].strip()) + break + depth -= 1 + elif c == ',' and depth == 0: + result.append(s[start: i].strip()) + start = i + 1 + return result + +def getType(target, typeName): + stars = 0 + + typeName = typeName.strip() + while typeName.endswith('*'): + stars += 1 + typeName = typeName[:-1] + + if typeName.startswith('const '): + typeName = typeName[6:] + + ty = target.FindFirstType(typeName.strip()) + for _ in range(stars): + ty = ty.GetPointerType() + + return ty + +def luau_variant_summary(valobj, internal_dict, options): + type_id = valobj.GetChildMemberWithName("typeid").GetValueAsUnsigned() + storage = valobj.GetChildMemberWithName("storage") + params = templateParams(valobj.GetType().GetCanonicalType().GetName()) + stored_type = params[type_id] + value = storage.Cast(stored_type.GetPointerType()).Dereference() + return stored_type.GetDisplayTypeName() + " [" + value.GetValue() + "]" + +class LuauVariantSyntheticChildrenProvider: + node_names = ["type", "value"] + + def __init__(self, valobj, internal_dict): + self.valobj = valobj + self.type_index = None + self.current_type = None + self.type_params = [] + self.stored_value = None + + def num_children(self): + return len(self.node_names) + + def has_children(self): + return True + + def get_child_index(self, name): + try: + return self.node_names.index(name) + except ValueError: + return -1 + + def get_child_at_index(self, index): + try: + node = self.node_names[index] + except IndexError: + return None + + if node == "type": + if self.current_type: + return self.valobj.CreateValueFromExpression(node, f"(const char*)\"{self.current_type.GetDisplayTypeName()}\"") + else: + return self.valobj.CreateValueFromExpression(node, "(const char*)\"\"") + elif node == "value": + if self.stored_value is not None: + if self.current_type is not None: + return self.valobj.CreateValueFromData(node, self.stored_value.GetData(), self.current_type) + else: + return self.valobj.CreateValueExpression(node, "(const char*)\"\"") + else: + return self.valobj.CreateValueFromExpression(node, "(const char*)\"\"") + else: + return None + + def update(self): + self.type_index = self.valobj.GetChildMemberWithName("typeid").GetValueAsSigned() + self.type_params = templateParams(self.valobj.GetType().GetCanonicalType().GetName()) + + if len(self.type_params) > self.type_index: + self.current_type = getType(self.valobj.GetTarget(), self.type_params[self.type_index]) + + if self.current_type: + storage = self.valobj.GetChildMemberWithName("storage") + self.stored_value = storage.Cast(self.current_type.GetPointerType()).Dereference() + else: + self.stored_value = None + else: + self.current_type = None + self.stored_value = None + + return False diff --git a/tools/lldb-formatters.lldb b/tools/lldb-formatters.lldb new file mode 100644 index 000000000..3868ac20c --- /dev/null +++ b/tools/lldb-formatters.lldb @@ -0,0 +1,2 @@ +type synthetic add -x "^Luau::Variant<.+>$" -l LuauVisualize.LuauVariantSyntheticChildrenProvider +type summary add -x "^Luau::Variant<.+>$" -l LuauVisualize.luau_variant_summary From adecd840675c642f1f553d39930aa8014e82faa7 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 17 Mar 2022 17:06:25 -0700 Subject: [PATCH 30/32] Sync to upstream/release/519 --- Analysis/include/Luau/Error.h | 1 + Analysis/include/Luau/Unifier.h | 2 + Analysis/src/Autocomplete.cpp | 186 +- Analysis/src/EmbeddedBuiltinDefinitions.cpp | 184 +- Analysis/src/Error.cpp | 11 +- Analysis/src/Module.cpp | 7 + Analysis/src/TypeInfer.cpp | 125 +- Analysis/src/TypeVar.cpp | 10 +- Analysis/src/Unifier.cpp | 78 +- Ast/src/Parser.cpp | 117 +- Compiler/src/Builtins.cpp | 4 +- Compiler/src/Compiler.cpp | 4 - Sources.cmake | 7 + VM/include/lua.h | 15 +- VM/src/lapi.cpp | 20 +- VM/src/lbaselib.cpp | 16 +- VM/src/lgc.h | 4 +- VM/src/lmem.cpp | 2 +- VM/src/lstring.cpp | 2 +- VM/src/ltable.cpp | 71 +- VM/src/ltm.cpp | 3 +- VM/src/ludata.cpp | 8 +- VM/src/ludata.h | 3 + VM/src/lvmexecute.cpp | 5 + tests/Autocomplete.test.cpp | 162 + tests/Compiler.test.cpp | 2 - tests/Conformance.test.cpp | 142 +- tests/Linter.test.cpp | 2 - tests/Module.test.cpp | 10 +- tests/NonstrictMode.test.cpp | 22 +- tests/TypeInfer.aliases.test.cpp | 21 + tests/TypeInfer.anyerror.test.cpp | 335 ++ tests/TypeInfer.builtins.test.cpp | 60 + tests/TypeInfer.definitions.test.cpp | 18 + tests/TypeInfer.functions.test.cpp | 1338 +++++ tests/TypeInfer.generics.test.cpp | 301 + tests/TypeInfer.intersectionTypes.test.cpp | 28 + tests/TypeInfer.loops.test.cpp | 473 ++ tests/TypeInfer.modules.test.cpp | 310 + tests/TypeInfer.oop.test.cpp | 275 + tests/TypeInfer.operators.test.cpp | 759 +++ tests/TypeInfer.primitives.test.cpp | 100 + tests/TypeInfer.refinements.test.cpp | 18 + tests/TypeInfer.singletons.test.cpp | 116 +- tests/TypeInfer.tables.test.cpp | 500 ++ tests/TypeInfer.test.cpp | 5135 ++--------------- tests/TypeInfer.typePacks.cpp | 83 + tests/TypeInfer.unionTypes.test.cpp | 16 + tests/conformance/basic.lua | 19 + tests/conformance/debugger.lua | 4 +- tests/conformance/errors.lua | 2 + tests/conformance/interrupt.lua | 11 + tools/{gdb-printers.py => gdb_printers.py} | 0 tools/lldb-formatters.lldb | 2 - tools/lldb_formatters.lldb | 2 + .../{LuauVisualize.py => lldb_formatters.py} | 0 56 files changed, 6017 insertions(+), 5134 deletions(-) create mode 100644 tests/TypeInfer.anyerror.test.cpp create mode 100644 tests/TypeInfer.functions.test.cpp create mode 100644 tests/TypeInfer.loops.test.cpp create mode 100644 tests/TypeInfer.modules.test.cpp create mode 100644 tests/TypeInfer.oop.test.cpp create mode 100644 tests/TypeInfer.operators.test.cpp create mode 100644 tests/TypeInfer.primitives.test.cpp create mode 100644 tests/conformance/interrupt.lua rename tools/{gdb-printers.py => gdb_printers.py} (100%) delete mode 100644 tools/lldb-formatters.lldb create mode 100644 tools/lldb_formatters.lldb rename tools/{LuauVisualize.py => lldb_formatters.py} (100%) diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index a71e02246..72350255e 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -107,6 +107,7 @@ struct FunctionDoesNotTakeSelf struct FunctionRequiresSelf { + // TODO: Delete with LuauAnyInIsOptionalIsOptional int requiredExtraNils = 0; bool operator==(const FunctionRequiresSelf& rhs) const; diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index 71958f4a1..f1ffbcc01 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -86,6 +86,8 @@ struct Unifier void tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer); TypeId widen(TypeId ty); + TypePackId widen(TypePackId tp); + TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); void cacheResult(TypeId subTy, TypeId superTy); diff --git a/Analysis/src/Autocomplete.cpp b/Analysis/src/Autocomplete.cpp index e94c432f8..492edf256 100644 --- a/Analysis/src/Autocomplete.cpp +++ b/Analysis/src/Autocomplete.cpp @@ -14,6 +14,7 @@ #include LUAU_FASTFLAGVARIABLE(LuauIfElseExprFixCompletionIssue, false); +LUAU_FASTFLAG(LuauSelfCallAutocompleteFix) static const std::unordered_set kStatementStartingKeywords = { "while", "if", "local", "repeat", "function", "do", "for", "return", "break", "continue", "type", "export"}; @@ -228,11 +229,22 @@ static std::optional findExpectedTypeAt(const Module& module, AstNode* n return *it; } +static bool checkTypeMatch(TypeArena* typeArena, TypeId subTy, TypeId superTy) +{ + InternalErrorReporter iceReporter; + UnifierSharedState unifierState(&iceReporter); + Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); + + return unifier.canUnify(subTy, superTy).empty(); +} + static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typeArena, AstNode* node, Position position, TypeId ty) { ty = follow(ty); auto canUnify = [&typeArena](TypeId subTy, TypeId superTy) { + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + InternalErrorReporter iceReporter; UnifierSharedState unifierState(&iceReporter); Unifier unifier(typeArena, Mode::Strict, Location(), Variance::Covariant, unifierState); @@ -249,20 +261,30 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ TypeId expectedType = follow(*typeAtPosition); - auto checkFunctionType = [&canUnify, &expectedType](const FunctionTypeVar* ftv) { - auto [retHead, retTail] = flatten(ftv->retType); - - if (!retHead.empty() && canUnify(retHead.front(), expectedType)) - return true; + auto checkFunctionType = [typeArena, &canUnify, &expectedType](const FunctionTypeVar* ftv) { + if (FFlag::LuauSelfCallAutocompleteFix) + { + if (std::optional firstRetTy = first(ftv->retType)) + return checkTypeMatch(typeArena, *firstRetTy, expectedType); - // We might only have a variadic tail pack, check if the element is compatible - if (retTail) + return false; + } + else { - if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + auto [retHead, retTail] = flatten(ftv->retType); + + if (!retHead.empty() && canUnify(retHead.front(), expectedType)) return true; - } - return false; + // We might only have a variadic tail pack, check if the element is compatible + if (retTail) + { + if (const VariadicTypePack* vtp = get(follow(*retTail)); vtp && canUnify(vtp->ty, expectedType)) + return true; + } + + return false; + } }; // We also want to suggest functions that return compatible result @@ -281,7 +303,10 @@ static TypeCorrectKind checkTypeCorrectKind(const Module& module, TypeArena* typ } } - return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + if (FFlag::LuauSelfCallAutocompleteFix) + return checkTypeMatch(typeArena, ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; + else + return canUnify(ty, expectedType) ? TypeCorrectKind::Correct : TypeCorrectKind::None; } enum class PropIndexType @@ -291,16 +316,22 @@ enum class PropIndexType Key, }; -static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, - AutocompleteEntryMap& result, std::unordered_set& seen, std::optional containingClass = std::nullopt) +static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId rootTy, TypeId ty, PropIndexType indexType, + const std::vector& nodes, AutocompleteEntryMap& result, std::unordered_set& seen, + std::optional containingClass = std::nullopt) { + if (FFlag::LuauSelfCallAutocompleteFix) + rootTy = follow(rootTy); + ty = follow(ty); if (seen.count(ty)) return; seen.insert(ty); - auto isWrongIndexer = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { + auto isWrongIndexer_DEPRECATED = [indexType, useStrictFunctionIndexers = !!get(ty)](Luau::TypeId type) { + LUAU_ASSERT(!FFlag::LuauSelfCallAutocompleteFix); + if (indexType == PropIndexType::Key) return false; @@ -331,6 +362,48 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId return colonIndex; } }; + auto isWrongIndexer = [typeArena, rootTy, indexType](Luau::TypeId type) { + LUAU_ASSERT(FFlag::LuauSelfCallAutocompleteFix); + + if (indexType == PropIndexType::Key) + return false; + + bool calledWithSelf = indexType == PropIndexType::Colon; + + auto isCompatibleCall = [typeArena, rootTy, calledWithSelf](const FunctionTypeVar* ftv) { + if (get(rootTy)) + { + // Calls on classes require strict match between how function is declared and how it's called + return calledWithSelf == ftv->hasSelf; + } + + if (std::optional firstArgTy = first(ftv->argTypes)) + { + if (checkTypeMatch(typeArena, rootTy, *firstArgTy)) + return calledWithSelf; + } + + return !calledWithSelf; + }; + + if (const FunctionTypeVar* ftv = get(type)) + return !isCompatibleCall(ftv); + + // For intersections, any part that is successful makes the whole call successful + if (const IntersectionTypeVar* itv = get(type)) + { + for (auto subType : itv->parts) + { + if (const FunctionTypeVar* ftv = get(Luau::follow(subType))) + { + if (isCompatibleCall(ftv)) + return false; + } + } + } + + return calledWithSelf; + }; auto fillProps = [&](const ClassTypeVar::Props& props) { for (const auto& [name, prop] : props) @@ -349,7 +422,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryKind::Property, type, prop.deprecated, - isWrongIndexer(type), + FFlag::LuauSelfCallAutocompleteFix ? isWrongIndexer(type) : isWrongIndexer_DEPRECATED(type), typeCorrect, containingClass, &prop, @@ -361,34 +434,60 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId } }; + auto fillMetatableProps = [&](const TableTypeVar* mtable) { + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) + { + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + { + autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); + } + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } + } + }; + if (auto cls = get(ty)) { containingClass = containingClass.value_or(cls); fillProps(cls->props); if (cls->parent) - autocompleteProps(module, typeArena, *cls->parent, indexType, nodes, result, seen, cls); + autocompleteProps(module, typeArena, rootTy, *cls->parent, indexType, nodes, result, seen, cls); } else if (auto tbl = get(ty)) fillProps(tbl->props); else if (auto mt = get(ty)) { - autocompleteProps(module, typeArena, mt->table, indexType, nodes, result, seen); - - auto mtable = get(mt->metatable); - if (!mtable) - return; + autocompleteProps(module, typeArena, rootTy, mt->table, indexType, nodes, result, seen); - auto indexIt = mtable->props.find("__index"); - if (indexIt != mtable->props.end()) + if (FFlag::LuauSelfCallAutocompleteFix) { - TypeId followed = follow(indexIt->second.type); - if (get(followed) || get(followed)) - autocompleteProps(module, typeArena, followed, indexType, nodes, result, seen); - else if (auto indexFunction = get(followed)) + if (auto mtable = get(mt->metatable)) + fillMetatableProps(mtable); + } + else + { + auto mtable = get(mt->metatable); + if (!mtable) + return; + + auto indexIt = mtable->props.find("__index"); + if (indexIt != mtable->props.end()) { - std::optional indexFunctionResult = first(indexFunction->retType); - if (indexFunctionResult) - autocompleteProps(module, typeArena, *indexFunctionResult, indexType, nodes, result, seen); + TypeId followed = follow(indexIt->second.type); + if (get(followed) || get(followed)) + autocompleteProps(module, typeArena, rootTy, followed, indexType, nodes, result, seen); + else if (auto indexFunction = get(followed)) + { + std::optional indexFunctionResult = first(indexFunction->retType); + if (indexFunctionResult) + autocompleteProps(module, typeArena, rootTy, *indexFunctionResult, indexType, nodes, result, seen); + } } } } @@ -400,7 +499,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId AutocompleteEntryMap inner; std::unordered_set innerSeen = seen; - autocompleteProps(module, typeArena, ty, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, rootTy, ty, indexType, nodes, inner, innerSeen); for (auto& pair : inner) result.insert(pair); @@ -423,14 +522,17 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId if (iter == endIter) return; - autocompleteProps(module, typeArena, *iter, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, rootTy, *iter, indexType, nodes, result, seen); ++iter; while (iter != endIter) { AutocompleteEntryMap inner; - std::unordered_set innerSeen = seen; + std::unordered_set innerSeen; + + if (!FFlag::LuauSelfCallAutocompleteFix) + innerSeen = seen; if (isNil(*iter)) { @@ -438,7 +540,7 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId continue; } - autocompleteProps(module, typeArena, *iter, indexType, nodes, inner, innerSeen); + autocompleteProps(module, typeArena, rootTy, *iter, indexType, nodes, inner, innerSeen); std::unordered_set toRemove; @@ -455,6 +557,18 @@ static void autocompleteProps(const Module& module, TypeArena* typeArena, TypeId ++iter; } } + else if (auto pt = get(ty); pt && FFlag::LuauSelfCallAutocompleteFix) + { + if (pt->metatable) + { + if (auto mtable = get(*pt->metatable)) + fillMetatableProps(mtable); + } + } + else if (FFlag::LuauSelfCallAutocompleteFix && get(get(ty))) + { + autocompleteProps(module, typeArena, rootTy, getSingletonTypes().stringType, indexType, nodes, result, seen); + } } static void autocompleteKeywords( @@ -482,7 +596,7 @@ static void autocompleteProps( const Module& module, TypeArena* typeArena, TypeId ty, PropIndexType indexType, const std::vector& nodes, AutocompleteEntryMap& result) { std::unordered_set seen; - autocompleteProps(module, typeArena, ty, indexType, nodes, result, seen); + autocompleteProps(module, typeArena, ty, ty, indexType, nodes, result, seen); } AutocompleteEntryMap autocompleteProps( @@ -1352,7 +1466,7 @@ static AutocompleteResult autocomplete(const SourceModule& sourceModule, const M TypeId ty = follow(*it); PropIndexType indexType = indexName->op == ':' ? PropIndexType::Colon : PropIndexType::Point; - if (isString(ty)) + if (!FFlag::LuauSelfCallAutocompleteFix && isString(ty)) return {autocompleteProps(*module, typeArena, typeChecker.globalScope->bindings[AstName{"string"}].typeId, indexType, finder.ancestry), finder.ancestry}; else diff --git a/Analysis/src/EmbeddedBuiltinDefinitions.cpp b/Analysis/src/EmbeddedBuiltinDefinitions.cpp index 471b61ad8..be3fcd7da 100644 --- a/Analysis/src/EmbeddedBuiltinDefinitions.cpp +++ b/Analysis/src/EmbeddedBuiltinDefinitions.cpp @@ -95,104 +95,104 @@ declare os: { declare function require(target: any): any -declare function getfenv(target: any?): { [string]: any } +declare function getfenv(target: any): { [string]: any } declare _G: any declare _VERSION: string declare function gcinfo(): number - declare function print(...: T...) - - declare function type(value: T): string - declare function typeof(value: T): string - - -- `assert` has a magic function attached that will give more detailed type information - declare function assert(value: T, errorMessage: string?): T - - declare function error(message: T, level: number?) - - declare function tostring(value: T): string - declare function tonumber(value: T, radix: number?): number? - - declare function rawequal(a: T1, b: T2): boolean - declare function rawget(tab: {[K]: V}, k: K): V - declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} - - declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? - - declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) - - declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) - - -- FIXME: The actual type of `xpcall` is: - -- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) - -- Since we can't represent the return value, we use (boolean, R1...). - declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) - - -- `select` has a magic function attached to provide more detailed type information - declare function select(i: string | number, ...: A...): ...any - - -- FIXME: This type is not entirely correct - `loadstring` returns a function or - -- (nil, string). - declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) - - declare function newproxy(mt: boolean?): any - - declare coroutine: { - create: ((A...) -> R...) -> thread, - resume: (thread, A...) -> (boolean, R...), - running: () -> thread, - status: (thread) -> string, - -- FIXME: This technically returns a function, but we can't represent this yet. - wrap: ((A...) -> R...) -> any, - yield: (A...) -> R..., - isyieldable: () -> boolean, - close: (thread) -> (boolean, any?) - } - - declare table: { - concat: ({V}, string?, number?, number?) -> string, - insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), - maxn: ({V}) -> number, - remove: ({V}, number?) -> V?, - sort: ({V}, ((V, V) -> boolean)?) -> (), - create: (number, V?) -> {V}, - find: ({V}, V, number?) -> number?, - - unpack: ({V}, number?, number?) -> ...V, - pack: (...V) -> { n: number, [number]: V }, - - getn: ({V}) -> number, - foreach: ({[K]: V}, (K, V) -> ()) -> (), - foreachi: ({V}, (number, V) -> ()) -> (), - - move: ({V}, number, number, number, {V}?) -> {V}, - clear: ({[K]: V}) -> (), - - isfrozen: ({[K]: V}) -> boolean, - } - - declare debug: { - info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), - traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), - } - - declare utf8: { - char: (number, ...number) -> string, - charpattern: string, - codes: (string) -> ((string, number) -> (number, number), string, number), - -- FIXME - codepoint: (string, number?, number?) -> (number, ...number), - len: (string, number?, number?) -> (number?, number?), - offset: (string, number?, number?) -> number, - nfdnormalize: (string) -> string, - nfcnormalize: (string) -> string, - graphemes: (string, number?, number?) -> (() -> (number, number)), - } - - -- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. - declare function unpack(tab: {V}, i: number?, j: number?): ...V +declare function print(...: T...) + +declare function type(value: T): string +declare function typeof(value: T): string + +-- `assert` has a magic function attached that will give more detailed type information +declare function assert(value: T, errorMessage: string?): T + +declare function error(message: T, level: number?) + +declare function tostring(value: T): string +declare function tonumber(value: T, radix: number?): number? + +declare function rawequal(a: T1, b: T2): boolean +declare function rawget(tab: {[K]: V}, k: K): V +declare function rawset(tab: {[K]: V}, k: K, v: V): {[K]: V} + +declare function setfenv(target: number | (T...) -> R..., env: {[string]: any}): ((T...) -> R...)? + +declare function ipairs(tab: {V}): (({V}, number) -> (number, V), {V}, number) + +declare function pcall(f: (A...) -> R..., ...: A...): (boolean, R...) + +-- FIXME: The actual type of `xpcall` is: +-- (f: (A...) -> R1..., err: (E) -> R2..., A...) -> (true, R1...) | (false, R2...) +-- Since we can't represent the return value, we use (boolean, R1...). +declare function xpcall(f: (A...) -> R1..., err: (E) -> R2..., ...: A...): (boolean, R1...) + +-- `select` has a magic function attached to provide more detailed type information +declare function select(i: string | number, ...: A...): ...any + +-- FIXME: This type is not entirely correct - `loadstring` returns a function or +-- (nil, string). +declare function loadstring(src: string, chunkname: string?): (((A...) -> any)?, string?) + +declare function newproxy(mt: boolean?): any + +declare coroutine: { + create: ((A...) -> R...) -> thread, + resume: (thread, A...) -> (boolean, R...), + running: () -> thread, + status: (thread) -> string, + -- FIXME: This technically returns a function, but we can't represent this yet. + wrap: ((A...) -> R...) -> any, + yield: (A...) -> R..., + isyieldable: () -> boolean, + close: (thread) -> (boolean, any) +} + +declare table: { + concat: ({V}, string?, number?, number?) -> string, + insert: (({V}, V) -> ()) & (({V}, number, V) -> ()), + maxn: ({V}) -> number, + remove: ({V}, number?) -> V?, + sort: ({V}, ((V, V) -> boolean)?) -> (), + create: (number, V?) -> {V}, + find: ({V}, V, number?) -> number?, + + unpack: ({V}, number?, number?) -> ...V, + pack: (...V) -> { n: number, [number]: V }, + + getn: ({V}) -> number, + foreach: ({[K]: V}, (K, V) -> ()) -> (), + foreachi: ({V}, (number, V) -> ()) -> (), + + move: ({V}, number, number, number, {V}?) -> {V}, + clear: ({[K]: V}) -> (), + + isfrozen: ({[K]: V}) -> boolean, +} + +declare debug: { + info: ((thread, number, string) -> R...) & ((number, string) -> R...) & (((A...) -> R1..., string) -> R2...), + traceback: ((string?, number?) -> string) & ((thread, string?, number?) -> string), +} + +declare utf8: { + char: (number, ...number) -> string, + charpattern: string, + codes: (string) -> ((string, number) -> (number, number), string, number), + -- FIXME + codepoint: (string, number?, number?) -> (number, ...number), + len: (string, number?, number?) -> (number?, number?), + offset: (string, number?, number?) -> number, + nfdnormalize: (string) -> string, + nfcnormalize: (string) -> string, + graphemes: (string, number?, number?) -> (() -> (number, number)), +} + +-- Cannot use `typeof` here because it will produce a polytype when we expect a monotype. +declare function unpack(tab: {V}, i: number?, j: number?): ...V )BUILTIN_SRC"; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 88069f1f5..26d3b76da 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -7,6 +7,8 @@ #include +LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); + static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { std::string s = "expects "; @@ -223,7 +225,14 @@ struct ErrorConverter std::string operator()(const Luau::SyntaxError& e) const { - return "Syntax error: " + e.message; + if (FFlag::BetterDiagnosticCodesInStudio) + { + return e.message; + } + else + { + return "Syntax error: " + e.message; + } } std::string operator()(const Luau::CodeTooComplex&) const diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index 76dc72d22..a330a98d1 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -14,6 +14,7 @@ LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) +LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) LUAU_FASTFLAG(LuauImmutableTypes) namespace Luau @@ -536,6 +537,12 @@ bool Module::clonePublicInterface() if (get(follow(ty))) *asMutable(ty) = AnyTypeVar{}; + if (FFlag::LuauCloneDeclaredGlobals) + { + for (auto& [name, ty] : declaredGlobals) + ty = clone(ty, interfaceTypes, seenTypes, seenTypePacks, cloneState); + } + freeze(internalTypes); freeze(interfaceTypes); diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 3fe4c90ef..41e8ce55f 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -29,22 +29,24 @@ LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) +LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) -LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify, false) +LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify2, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. -LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree) +LUAU_FASTFLAG(LuauWidenIfSupertypeIsFree2) LUAU_FASTFLAGVARIABLE(LuauDoNotTryToReduce, false) LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) +LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) namespace Luau { @@ -1099,7 +1101,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco scope->bindings[name->local] = {anyIfNonstrict(quantify(funScope, ty, name->local->location)), name->local->location}; return; } - else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify) + else if (auto name = function.name->as(); name && FFlag::LuauStatFunctionSimplify2) { TypeId exprTy = checkExpr(scope, *name->expr).type; TableTypeVar* ttv = getMutableTableType(exprTy); @@ -1111,7 +1113,10 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco reportError(TypeError{function.location, OnlyTablesCanHaveMethods{exprTy}}); } else if (ttv->state == TableState::Sealed) - reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + { + if (!ttv->indexer || !isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) + reportError(TypeError{function.location, CannotExtendTable{exprTy, CannotExtendTable::Property, name->index.value}}); + } ty = follow(ty); @@ -1134,7 +1139,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco if (ttv && ttv->state != TableState::Sealed) ttv->props[name->index.value] = {follow(quantify(funScope, ty, name->indexLocation)), /* deprecated */ false, {}, name->indexLocation}; } - else if (FFlag::LuauStatFunctionSimplify) + else if (FFlag::LuauStatFunctionSimplify2) { LUAU_ASSERT(function.name->is()); @@ -1144,7 +1149,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else if (function.func->self) { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); AstExprIndexName* indexName = function.name->as(); if (!indexName) @@ -1183,7 +1188,7 @@ void TypeChecker::check(const ScopePtr& scope, TypeId ty, const ScopePtr& funSco } else { - LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify); + LUAU_ASSERT(!FFlag::LuauStatFunctionSimplify2); TypeId leftType = checkLValueBinding(scope, *function.name); @@ -1410,6 +1415,9 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatDeclareClass& declar { ftv->argNames.insert(ftv->argNames.begin(), FunctionArgument{"self", {}}); ftv->argTypes = addTypePack(TypePack{{classTy}, ftv->argTypes}); + + if (FFlag::LuauSelfCallAutocompleteFix) + ftv->hasSelf = true; } } @@ -1883,19 +1891,27 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) { if (const UnionTypeVar* utv = get(ty)) { - bool hasNil = false; - - for (TypeId option : utv) + if (FFlag::LuauAnyInIsOptionalIsOptional) { - if (isNil(option)) + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + } + else + { + bool hasNil = false; + + for (TypeId option : utv) { - hasNil = true; - break; + if (isNil(option)) + { + hasNil = true; + break; + } } - } - if (!hasNil) - return ty; + if (!hasNil) + return ty; + } std::vector result; @@ -1916,14 +1932,34 @@ std::optional TypeChecker::tryStripUnionFromNil(TypeId ty) TypeId TypeChecker::stripFromNilAndReport(TypeId ty, const Location& location) { - if (isOptional(ty)) + if (FFlag::LuauAnyInIsOptionalIsOptional) { - if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) + ty = follow(ty); + + if (auto utv = get(ty)) + { + if (!std::any_of(begin(utv), end(utv), isNil)) + return ty; + + } + + if (std::optional strippedUnion = tryStripUnionFromNil(ty)) { reportError(location, OptionalValueAccess{ty}); return follow(*strippedUnion); } } + else + { + if (isOptional(ty)) + { + if (std::optional strippedUnion = tryStripUnionFromNil(follow(ty))) + { + reportError(location, OptionalValueAccess{ty}); + return follow(*strippedUnion); + } + } + } return ty; } @@ -2935,9 +2971,25 @@ TypeId TypeChecker::checkFunctionName(const ScopePtr& scope, AstExpr& funName, T return errorRecoveryType(scope); } - // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check - if (lhsType->persistent || ttv->state == TableState::Sealed) - return errorRecoveryType(scope); + if (FFlag::LuauStatFunctionSimplify2) + { + if (lhsType->persistent) + return errorRecoveryType(scope); + + // Cannot extend sealed table, but we dont report an error here because it will be reported during AstStatFunction check + if (ttv->state == TableState::Sealed) + { + if (ttv->indexer && isPrim(ttv->indexer->indexType, PrimitiveTypeVar::String)) + return ttv->indexer->indexResultType; + else + return errorRecoveryType(scope); + } + } + else + { + if (lhsType->persistent || ttv->state == TableState::Sealed) + return errorRecoveryType(scope); + } Name name = indexName->index.value; @@ -3393,7 +3445,7 @@ void TypeChecker::checkArgumentList( else if (state.log.getMutable(t)) { } // ok - else if (isNonstrictMode() && state.log.getMutable(t)) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && state.log.getMutable(t)) { } // ok else @@ -3467,7 +3519,11 @@ void TypeChecker::checkArgumentList( } TypePackId varPack = addTypePack(TypePackVar{TypePack{rest, argIter.tail()}}); - state.tryUnify(varPack, tail); + if (FFlag::LuauWidenIfSupertypeIsFree2) + state.tryUnify(varPack, tail); + else + state.tryUnify(tail, varPack); + return; } else if (state.log.getMutable(tail)) @@ -3542,6 +3598,23 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A actualFunctionType = follow(actualFunctionType); + TypePackId retPack; + if (!FFlag::LuauWidenIfSupertypeIsFree2) + { + retPack = freshTypePack(scope->level); + } + else + { + if (auto free = get(actualFunctionType)) + { + retPack = freshTypePack(free->level); + TypePackId freshArgPack = freshTypePack(free->level); + *asMutable(actualFunctionType) = FunctionTypeVar(free->level, freshArgPack, retPack); + } + else + retPack = freshTypePack(scope->level); + } + // checkExpr will log the pre-instantiated type of the function. // That's not nearly as interesting as the instantiated type, which will include details about how // generic functions are being instantiated for this particular callsite. @@ -3550,8 +3623,6 @@ ExprResult TypeChecker::checkExprPack(const ScopePtr& scope, const A std::vector overloads = flattenIntersection(actualFunctionType); - TypePackId retPack = freshTypePack(scope->level); - std::vector> expectedTypes = getExpectedTypesForCall(overloads, expr.args.size, expr.self); ExprResult argListResult = checkExprList(scope, expr.location, expr.args, false, {}, expectedTypes); @@ -3682,7 +3753,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope // has been instantiated, so is a monotype. We can therefore // unify it with a monomorphic function. TypeId r = addType(FunctionTypeVar(scope->level, argPack, retPack)); - if (FFlag::LuauWidenIfSupertypeIsFree) + if (FFlag::LuauWidenIfSupertypeIsFree2) { UnifierOptions options; options.isFunctionCall = true; @@ -3772,7 +3843,7 @@ std::optional> TypeChecker::checkCallOverload(const Scope { state.log.commit(); - if (isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) + if (!FFlag::LuauAnyInIsOptionalIsOptional && isNonstrictMode() && !expr.self && expr.func->is() && ftv->hasSelf) { // If we are running in nonstrict mode, passing fewer arguments than the function is declared to take AND // the function is declared with colon notation AND we use dot notation, warn. diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 5af2c8a62..895495352 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -26,6 +26,7 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit) LUAU_FASTFLAG(LuauErrorRecoveryType) LUAU_FASTFLAG(LuauSubtypingAddOptPropsToUnsealedTables) LUAU_FASTFLAG(LuauDiscriminableUnions2) +LUAU_FASTFLAGVARIABLE(LuauAnyInIsOptionalIsOptional, false) namespace Luau { @@ -201,11 +202,16 @@ bool isOptional(TypeId ty) if (isNil(ty)) return true; - auto utv = get(follow(ty)); + ty = follow(ty); + + if (FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) + return true; + + auto utv = get(ty); if (!utv) return false; - return std::any_of(begin(utv), end(utv), isNil); + return std::any_of(begin(utv), end(utv), FFlag::LuauAnyInIsOptionalIsOptional ? isOptional : isNil); } bool isTableIntersection(TypeId ty) diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 7b781f269..60a9c9a5d 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -20,11 +20,13 @@ LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) -LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree, false) +LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) LUAU_FASTFLAGVARIABLE(LuauDifferentOrderOfUnificationDoesntMatter, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) +LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) +LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) namespace Luau { @@ -272,12 +274,6 @@ TypePackId Widen::clean(TypePackId) bool Widen::ignoreChildren(TypeId ty) { - // Sometimes we unify ("hi") -> free1 with (free2) -> free3, so don't ignore functions. - // TODO: should we be doing this? we would need to rework how checkCallOverload does the unification. - if (log->is(ty)) - return false; - - // We only care about unions. return !log->is(ty); } @@ -990,7 +986,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal if (!log.getMutable(superTp)) { - log.replace(superTp, Unifiable::Bound(subTp)); + log.replace(superTp, Unifiable::Bound(widen(subTp))); } } else if (log.getMutable(subTp)) @@ -1107,7 +1103,7 @@ void Unifier::tryUnify_(TypePackId subTp, TypePackId superTp, bool isFunctionCal } // In nonstrict mode, any also marks an optional argument. - else if (superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) + else if (!FFlag::LuauAnyInIsOptionalIsOptional && superIter.good() && isNonstrictMode() && log.getMutable(log.follow(*superIter))) { superIter.advance(); continue; @@ -1280,6 +1276,13 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal tryUnify_(subFunction->retType, superFunction->retType); } + if (FFlag::LuauTxnLogRefreshFunctionPointers) + { + // Updating the log may have invalidated the function pointers + superFunction = log.getMutable(superTy); + subFunction = log.getMutable(subTy); + } + if (!FFlag::LuauImmutableTypes) { if (superFunction->definition && !subFunction->definition && !subTy->persistent) @@ -1357,10 +1360,18 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto subIter = subTable->props.find(propName); - bool isAny = log.getMutable(log.follow(superProp.type)); + if (FFlag::LuauAnyInIsOptionalIsOptional) + { + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type)) + missingProperties.push_back(propName); + } + else + { + bool isAny = log.getMutable(log.follow(superProp.type)); - if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) - missingProperties.push_back(propName); + if (subIter == subTable->props.end() && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && !isOptional(superProp.type) && !isAny) + missingProperties.push_back(propName); + } } if (!missingProperties.empty()) @@ -1378,9 +1389,17 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) { auto superIter = superTable->props.find(propName); - bool isAny = log.is(log.follow(subProp.type)); - if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) - extraProperties.push_back(propName); + if (FFlag::LuauAnyInIsOptionalIsOptional) + { + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || !isOptional(subProp.type))) + extraProperties.push_back(propName); + } + else + { + bool isAny = log.is(log.follow(subProp.type)); + if (superIter == superTable->props.end() && (FFlag::LuauSubtypingAddOptPropsToUnsealedTables || (!isOptional(subProp.type) && !isAny))) + extraProperties.push_back(propName); + } } if (!extraProperties.empty()) @@ -1424,6 +1443,12 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) if (innerState.errors.empty()) log.concat(std::move(innerState.log)); } + else if (FFlag::LuauAnyInIsOptionalIsOptional && (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && isOptional(prop.type)) + // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` + // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. + // TODO: if the supertype is written to, the subtype may no longer be precise (alias analysis?) + { + } else if ((!FFlag::LuauSubtypingAddOptPropsToUnsealedTables || subTable->state == TableState::Unsealed) && (isOptional(prop.type) || get(follow(prop.type)))) // This is sound because unsealed table types are precise, so `{ p : T } <: { p : T, q : U? }` // since if `t : { p : T }` then we are guaranteed that `t.q` is `nil`. @@ -1497,6 +1522,9 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) else if (variance == Covariant) { } + else if (FFlag::LuauAnyInIsOptionalIsOptional && !FFlag::LuauSubtypingAddOptPropsToUnsealedTables && isOptional(prop.type)) + { + } else if (!FFlag::LuauSubtypingAddOptPropsToUnsealedTables && (isOptional(prop.type) || get(follow(prop.type)))) { } @@ -1618,7 +1646,7 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) TypeId Unifier::widen(TypeId ty) { - if (!FFlag::LuauWidenIfSupertypeIsFree) + if (!FFlag::LuauWidenIfSupertypeIsFree2) return ty; Widen widen{types}; @@ -1627,10 +1655,21 @@ TypeId Unifier::widen(TypeId ty) return result.value_or(ty); } +TypePackId Unifier::widen(TypePackId tp) +{ + if (!FFlag::LuauWidenIfSupertypeIsFree2) + return tp; + + Widen widen{types}; + std::optional result = widen.substitute(tp); + // TODO: what does it mean for substitution to fail to widen? + return result.value_or(tp); +} + TypeId Unifier::deeplyOptional(TypeId ty, std::unordered_map seen) { ty = follow(ty); - if (get(ty)) + if (!FFlag::LuauAnyInIsOptionalIsOptional && get(ty)) return ty; else if (isOptional(ty)) return ty; @@ -1744,7 +1783,10 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { if (auto subProp = findTablePropertyRespectingMeta(subTy, freeName)) { - tryUnify_(freeProp.type, *subProp); + if (FFlag::LuauWidenIfSupertypeIsFree2) + tryUnify_(*subProp, freeProp.type); + else + tryUnify_(freeProp.type, *subProp); /* * TypeVars are commonly cyclic, so it is entirely possible diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 1cb8f1343..941a3ea4f 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -11,18 +11,11 @@ LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauParseAllHotComments, false) LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) namespace Luau { -static bool isComment(const Lexeme& lexeme) -{ - LUAU_ASSERT(!FFlag::LuauParseAllHotComments); - return lexeme.type == Lexeme::Comment || lexeme.type == Lexeme::BlockComment; -} - ParseError::ParseError(const Location& location, const std::string& message) : location(location) , message(message) @@ -146,54 +139,13 @@ ParseResult Parser::parse(const char* buffer, size_t bufferSize, AstNameTable& n { LUAU_TIMETRACE_SCOPE("Parser::parse", "Parser"); - Parser p(buffer, bufferSize, names, allocator, FFlag::LuauParseAllHotComments ? options : ParseOptions()); + Parser p(buffer, bufferSize, names, allocator, options); try { - if (FFlag::LuauParseAllHotComments) - { - AstStatBlock* root = p.parseChunk(); - - return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; - } - else - { - std::vector hotcomments; - - while (isComment(p.lexer.current()) || p.lexer.current().type == Lexeme::BrokenComment) - { - const char* text = p.lexer.current().data; - unsigned int length = p.lexer.current().length; - - if (length && text[0] == '!') - { - unsigned int end = length; - while (end > 0 && isSpace(text[end - 1])) - --end; - - hotcomments.push_back({true, p.lexer.current().location, std::string(text + 1, text + end)}); - } - - const Lexeme::Type type = p.lexer.current().type; - const Location loc = p.lexer.current().location; - - if (options.captureComments) - p.commentLocations.push_back(Comment{type, loc}); - - if (type == Lexeme::BrokenComment) - break; - - p.lexer.next(); - } + AstStatBlock* root = p.parseChunk(); - p.lexer.setSkipComments(true); - - p.options = options; - - AstStatBlock* root = p.parseChunk(); - - return ParseResult{root, hotcomments, p.parseErrors, std::move(p.commentLocations)}; - } + return ParseResult{root, std::move(p.hotcomments), std::move(p.parseErrors), std::move(p.commentLocations)}; } catch (ParseError& err) { @@ -225,10 +177,11 @@ Parser::Parser(const char* buffer, size_t bufferSize, AstNameTable& names, Alloc matchRecoveryStopOnToken.assign(Lexeme::Type::Reserved_END, 0); matchRecoveryStopOnToken[Lexeme::Type::Eof] = 1; - if (FFlag::LuauParseAllHotComments) - lexer.setSkipComments(true); + // required for lookahead() to work across a comment boundary and for nextLexeme() to work when captureComments is false + lexer.setSkipComments(true); - // read first lexeme + // read first lexeme (any hot comments get .header = true) + LUAU_ASSERT(hotcommentHeader); nextLexeme(); // all hot comments parsed after the first non-comment lexeme are special in that they don't affect type checking / linting mode @@ -2831,49 +2784,31 @@ void Parser::nextLexeme() { if (options.captureComments) { - if (FFlag::LuauParseAllHotComments) - { - Lexeme::Type type = lexer.next(/* skipComments= */ false).type; + Lexeme::Type type = lexer.next(/* skipComments= */ false).type; - while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) - { - const Lexeme& lexeme = lexer.current(); - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (lexeme.type == Lexeme::BrokenComment) - return; + while (type == Lexeme::BrokenComment || type == Lexeme::Comment || type == Lexeme::BlockComment) + { + const Lexeme& lexeme = lexer.current(); + commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - // Comments starting with ! are called "hot comments" and contain directives for type checking / linting - if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') - { - const char* text = lexeme.data; + // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. + // The parser will turn this into a proper syntax error. + if (lexeme.type == Lexeme::BrokenComment) + return; - unsigned int end = lexeme.length; - while (end > 0 && isSpace(text[end - 1])) - --end; + // Comments starting with ! are called "hot comments" and contain directives for type checking / linting + if (lexeme.type == Lexeme::Comment && lexeme.length && lexeme.data[0] == '!') + { + const char* text = lexeme.data; - hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); - } + unsigned int end = lexeme.length; + while (end > 0 && isSpace(text[end - 1])) + --end; - type = lexer.next(/* skipComments= */ false).type; - } - } - else - { - while (true) - { - const Lexeme& lexeme = lexer.next(/*skipComments*/ false); - // Subtlety: Broken comments are weird because we record them as comments AND pass them to the parser as a lexeme. - // The parser will turn this into a proper syntax error. - if (lexeme.type == Lexeme::BrokenComment) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - if (isComment(lexeme)) - commentLocations.push_back(Comment{lexeme.type, lexeme.location}); - else - return; + hotcomments.push_back({hotcommentHeader, lexeme.location, std::string(text + 1, text + end)}); } + + type = lexer.next(/* skipComments= */ false).type; } } else diff --git a/Compiler/src/Builtins.cpp b/Compiler/src/Builtins.cpp index 26360c495..ff7531128 100644 --- a/Compiler/src/Builtins.cpp +++ b/Compiler/src/Builtins.cpp @@ -4,8 +4,6 @@ #include "Luau/Bytecode.h" #include "Luau/Compiler.h" -LUAU_FASTFLAGVARIABLE(LuauCompileSelectBuiltin2, false) - namespace Luau { namespace Compile @@ -64,7 +62,7 @@ int getBuiltinFunctionId(const Builtin& builtin, const CompileOptions& options) if (builtin.isGlobal("unpack")) return LBF_TABLE_UNPACK; - if (FFlag::LuauCompileSelectBuiltin2 && builtin.isGlobal("select")) + if (builtin.isGlobal("select")) return LBF_SELECT_VARARG; if (builtin.object == "math") diff --git a/Compiler/src/Compiler.cpp b/Compiler/src/Compiler.cpp index 656a99265..6330bf1ff 100644 --- a/Compiler/src/Compiler.cpp +++ b/Compiler/src/Compiler.cpp @@ -15,8 +15,6 @@ #include #include -LUAU_FASTFLAG(LuauCompileSelectBuiltin2) - namespace Luau { @@ -265,7 +263,6 @@ struct Compiler void compileExprSelectVararg(AstExprCall* expr, uint8_t target, uint8_t targetCount, bool targetTop, bool multRet, uint8_t regs) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); LUAU_ASSERT(targetCount == 1); LUAU_ASSERT(!expr->self); LUAU_ASSERT(expr->args.size == 2 && expr->args.data[1]->is()); @@ -407,7 +404,6 @@ struct Compiler if (bfid == LBF_SELECT_VARARG) { - LUAU_ASSERT(FFlag::LuauCompileSelectBuiltin2); // Optimization: compile select(_, ...) as FASTCALL1; the builtin will read variadic arguments directly // note: for now we restrict this to single-return expressions since our runtime code doesn't deal with general cases if (multRet == false && targetCount == 1 && expr->args.size == 2 && expr->args.data[1]->is()) diff --git a/Sources.cmake b/Sources.cmake index 615641eb1..59b384971 100644 --- a/Sources.cmake +++ b/Sources.cmake @@ -232,11 +232,18 @@ if(TARGET Luau.UnitTest) tests/Transpiler.test.cpp tests/TypeInfer.aliases.test.cpp tests/TypeInfer.annotations.test.cpp + tests/TypeInfer.anyerror.test.cpp tests/TypeInfer.builtins.test.cpp tests/TypeInfer.classes.test.cpp tests/TypeInfer.definitions.test.cpp + tests/TypeInfer.functions.test.cpp tests/TypeInfer.generics.test.cpp tests/TypeInfer.intersectionTypes.test.cpp + tests/TypeInfer.loops.test.cpp + tests/TypeInfer.modules.test.cpp + tests/TypeInfer.oop.test.cpp + tests/TypeInfer.operators.test.cpp + tests/TypeInfer.primitives.test.cpp tests/TypeInfer.provisional.test.cpp tests/TypeInfer.refinements.test.cpp tests/TypeInfer.singletons.test.cpp diff --git a/VM/include/lua.h b/VM/include/lua.h index 274c4ed9f..d08b73eab 100644 --- a/VM/include/lua.h +++ b/VM/include/lua.h @@ -172,9 +172,12 @@ LUA_API const char* lua_pushvfstring(lua_State* L, const char* fmt, va_list argp LUA_API LUA_PRINTF_ATTR(2, 3) const char* lua_pushfstringL(lua_State* L, const char* fmt, ...); LUA_API void lua_pushcclosurek(lua_State* L, lua_CFunction fn, const char* debugname, int nup, lua_Continuation cont); LUA_API void lua_pushboolean(lua_State* L, int b); -LUA_API void lua_pushlightuserdata(lua_State* L, void* p); LUA_API int lua_pushthread(lua_State* L); +LUA_API void lua_pushlightuserdata(lua_State* L, void* p); +LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); +LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); + /* ** get functions (Lua -> stack) */ @@ -189,8 +192,6 @@ LUA_API void lua_setreadonly(lua_State* L, int idx, int enabled); LUA_API int lua_getreadonly(lua_State* L, int idx); LUA_API void lua_setsafeenv(lua_State* L, int idx, int enabled); -LUA_API void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag); -LUA_API void* lua_newuserdatadtor(lua_State* L, size_t sz, void (*dtor)(void*)); LUA_API int lua_getmetatable(lua_State* L, int objindex); LUA_API void lua_getfenv(lua_State* L, int idx); @@ -276,6 +277,14 @@ enum lua_GCOp LUA_API int lua_gc(lua_State* L, int what, int data); +/* +** memory statistics +** all allocated bytes are attributed to the memory category of the running thread (0..LUA_MEMORY_CATEGORIES-1) +*/ + +LUA_API void lua_setmemcat(lua_State* L, int category); +LUA_API size_t lua_totalbytes(lua_State* L, int category); + /* ** miscellaneous functions */ diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index f7f154428..3c0873147 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -35,8 +35,8 @@ const char* luau_ident = "$Luau: Copyright (C) 2019-2022 Roblox Corporation $\n" static Table* getcurrenv(lua_State* L) { - if (L->ci == L->base_ci) /* no enclosing function? */ - return L->gt; /* use global table as environment */ + if (L->ci == L->base_ci) /* no enclosing function? */ + return L->gt; /* use global table as environment */ else return curr_func(L)->env; } @@ -1188,7 +1188,7 @@ void lua_concat(lua_State* L, int n) void* lua_newuserdatatagged(lua_State* L, size_t sz, int tag) { - api_check(L, unsigned(tag) < LUA_UTAG_LIMIT); + api_check(L, unsigned(tag) < LUA_UTAG_LIMIT || tag == UTAG_PROXY); luaC_checkGC(L); luaC_checkthreadsleep(L); Udata* u = luaU_newudata(L, sz, tag); @@ -1317,7 +1317,7 @@ void lua_setuserdatadtor(lua_State* L, int tag, void (*dtor)(void*)) L->global->udatagc[tag] = dtor; } -LUA_API void lua_clonefunction(lua_State* L, int idx) +void lua_clonefunction(lua_State* L, int idx) { StkId p = index2addr(L, idx); api_check(L, isLfunction(p)); @@ -1333,3 +1333,15 @@ lua_Callbacks* lua_callbacks(lua_State* L) { return &L->global->cb; } + +void lua_setmemcat(lua_State* L, int category) +{ + api_check(L, unsigned(category) < LUA_MEMORY_CATEGORIES); + L->activememcat = uint8_t(category); +} + +size_t lua_totalbytes(lua_State* L, int category) +{ + api_check(L, category < LUA_MEMORY_CATEGORIES); + return category < 0 ? L->global->totalbytes : L->global->memcatbytes[category]; +} diff --git a/VM/src/lbaselib.cpp b/VM/src/lbaselib.cpp index 2307598e8..96ad493b0 100644 --- a/VM/src/lbaselib.cpp +++ b/VM/src/lbaselib.cpp @@ -5,6 +5,7 @@ #include "lstate.h" #include "lapi.h" #include "ldo.h" +#include "ludata.h" #include #include @@ -190,6 +191,7 @@ static int luaB_type(lua_State* L) luaL_checkany(L, 1); if (DFFlag::LuauMorePreciseLuaLTypeName) { + /* resulting name doesn't differentiate between userdata types */ lua_pushstring(L, lua_typename(L, lua_type(L, 1))); } else @@ -202,8 +204,16 @@ static int luaB_type(lua_State* L) static int luaB_typeof(lua_State* L) { luaL_checkany(L, 1); - const TValue* obj = luaA_toobject(L, 1); - lua_pushstring(L, luaT_objtypename(L, obj)); + if (DFFlag::LuauMorePreciseLuaLTypeName) + { + /* resulting name returns __type if specified unless the input is a newproxy-created userdata */ + lua_pushstring(L, luaL_typename(L, 1)); + } + else + { + const TValue* obj = luaA_toobject(L, 1); + lua_pushstring(L, luaT_objtypename(L, obj)); + } return 1; } @@ -412,7 +422,7 @@ static int luaB_newproxy(lua_State* L) bool needsmt = lua_toboolean(L, 1); - lua_newuserdata(L, 0); + lua_newuserdatatagged(L, 0, UTAG_PROXY); if (needsmt) { diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ad8ee78a0..ebf999b53 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -9,9 +9,9 @@ /* ** Default settings for GC tunables (settable via lua_gc) */ -#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ +#define LUAI_GCGOAL 200 /* 200% (allow heap to double compared to live heap size) */ #define LUAI_GCSTEPMUL 200 /* GC runs 'twice the speed' of memory allocation */ -#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ +#define LUAI_GCSTEPSIZE 1 /* GC runs every KB of memory allocation */ /* ** Possible states of the Garbage Collector diff --git a/VM/src/lmem.cpp b/VM/src/lmem.cpp index 899cb0c0c..3cbdafff2 100644 --- a/VM/src/lmem.cpp +++ b/VM/src/lmem.cpp @@ -98,7 +98,7 @@ */ #if defined(__APPLE__) #define ABISWITCH(x64, ms32, gcc32) (sizeof(void*) == 8 ? x64 : gcc32) -#elif defined(__i386__) +#elif defined(__i386__) && !defined(_MSC_VER) #define ABISWITCH(x64, ms32, gcc32) (gcc32) #else // Android somehow uses a similar ABI to MSVC, *not* to iOS... diff --git a/VM/src/lstring.cpp b/VM/src/lstring.cpp index 872501468..c0cd3e261 100644 --- a/VM/src/lstring.cpp +++ b/VM/src/lstring.cpp @@ -53,7 +53,7 @@ void luaS_resize(lua_State* L, int newsize) { TString* p = tb->hash[i]; while (p) - { /* for each node in the list */ + { /* for each node in the list */ TString* next = p->next; /* save next */ unsigned int h = p->hash; int h1 = lmod(h, newsize); /* new position */ diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index ef0b4b93a..2deec2b9a 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -24,6 +24,8 @@ #include +LUAU_FASTFLAGVARIABLE(LuauTableRehashRework, false) + // max size of both array and hash part is 2^MAXBITS #define MAXBITS 26 #define MAXSIZE (1 << MAXBITS) @@ -351,6 +353,22 @@ static void setnodevector(lua_State* L, Table* t, int size) t->lastfree = size; /* all positions are free */ } +static TValue* newkey(lua_State* L, Table* t, const TValue* key); + +static TValue* arrayornewkey(lua_State* L, Table* t, const TValue* key) +{ + if (ttisnumber(key)) + { + int k; + double n = nvalue(key); + luai_num2int(k, n); + if (luai_numeq(cast_num(k), n) && cast_to(unsigned int, k - 1) < cast_to(unsigned int, t->sizearray)) + return &t->array[k - 1]; + } + + return newkey(L, t, key); +} + static void resize(lua_State* L, Table* t, int nasize, int nhsize) { if (nasize > MAXSIZE || nhsize > MAXSIZE) @@ -369,22 +387,50 @@ static void resize(lua_State* L, Table* t, int nasize, int nhsize) for (int i = nasize; i < oldasize; i++) { if (!ttisnil(&t->array[i])) - setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); + { + if (FFlag::LuauTableRehashRework) + { + TValue ok; + setnvalue(&ok, cast_num(i + 1)); + setobjt2t(L, newkey(L, t, &ok), &t->array[i]); + } + else + { + setobjt2t(L, luaH_setnum(L, t, i + 1), &t->array[i]); + } + } } /* shrink array */ luaM_reallocarray(L, t->array, oldasize, nasize, TValue, t->memcat); } /* re-insert elements from hash part */ - for (int i = twoto(oldhsize) - 1; i >= 0; i--) + if (FFlag::LuauTableRehashRework) { - LuaNode* old = nold + i; - if (!ttisnil(gval(old))) + for (int i = twoto(oldhsize) - 1; i >= 0; i--) { - TValue ok; - getnodekey(L, &ok, old); - setobjt2t(L, luaH_set(L, t, &ok), gval(old)); + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) + { + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, arrayornewkey(L, t, &ok), gval(old)); + } } } + else + { + for (int i = twoto(oldhsize) - 1; i >= 0; i--) + { + LuaNode* old = nold + i; + if (!ttisnil(gval(old))) + { + TValue ok; + getnodekey(L, &ok, old); + setobjt2t(L, luaH_set(L, t, &ok), gval(old)); + } + } + } + if (nold != dummynode) luaM_freearray(L, nold, twoto(oldhsize), LuaNode, t->memcat); /* free old array */ } @@ -482,7 +528,16 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) if (n == NULL) { /* cannot find a free place? */ rehash(L, t, key); /* grow table */ - return luaH_set(L, t, key); /* re-insert key into grown table */ + + if (!FFlag::LuauTableRehashRework) + { + return luaH_set(L, t, key); /* re-insert key into grown table */ + } + else + { + // after rehash, numeric keys might be located in the new array part, but won't be found in the node part + return arrayornewkey(L, t, key); + } } LUAU_ASSERT(n != dummynode); TValue mk; diff --git a/VM/src/ltm.cpp b/VM/src/ltm.cpp index a77a7c720..106efb2b4 100644 --- a/VM/src/ltm.cpp +++ b/VM/src/ltm.cpp @@ -4,6 +4,7 @@ #include "lstate.h" #include "lstring.h" +#include "ludata.h" #include "ltable.h" #include "lgc.h" @@ -116,7 +117,7 @@ const TValue* luaT_gettmbyobj(lua_State* L, const TValue* o, TMS event) const TString* luaT_objtypenamestr(lua_State* L, const TValue* o) { - if (ttisuserdata(o) && uvalue(o)->tag && uvalue(o)->metatable) + if (ttisuserdata(o) && uvalue(o)->tag != UTAG_PROXY && uvalue(o)->metatable) { const TValue* type = luaH_getstr(uvalue(o)->metatable, L->global->tmname[TM_TYPE]); diff --git a/VM/src/ludata.cpp b/VM/src/ludata.cpp index 0dfac508f..819d1863c 100644 --- a/VM/src/ludata.cpp +++ b/VM/src/ludata.cpp @@ -22,13 +22,11 @@ Udata* luaU_newudata(lua_State* L, size_t s, int tag) void luaU_freeudata(lua_State* L, Udata* u, lua_Page* page) { - LUAU_ASSERT(u->tag < LUA_UTAG_LIMIT || u->tag == UTAG_IDTOR); - void (*dtor)(void*) = nullptr; - if (u->tag == UTAG_IDTOR) - memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); - else if (u->tag) + if (u->tag < LUA_UTAG_LIMIT) dtor = L->global->udatagc[u->tag]; + else if (u->tag == UTAG_IDTOR) + memcpy(&dtor, &u->data + u->len - sizeof(dtor), sizeof(dtor)); if (dtor) dtor(u->data); diff --git a/VM/src/ludata.h b/VM/src/ludata.h index ec374c28b..f24e4a328 100644 --- a/VM/src/ludata.h +++ b/VM/src/ludata.h @@ -7,6 +7,9 @@ /* special tag value is used for user data with inline dtors */ #define UTAG_IDTOR LUA_UTAG_LIMIT +/* special tag value is used for newproxy-created user data (all other user data objects are host-exposed) */ +#define UTAG_PROXY (LUA_UTAG_LIMIT + 1) + #define sizeudata(len) (offsetof(Udata, data) + len) LUAI_FUNC Udata* luaU_newudata(lua_State* L, size_t s, int tag); diff --git a/VM/src/lvmexecute.cpp b/VM/src/lvmexecute.cpp index 6c31d36f2..96a87b7ea 100644 --- a/VM/src/lvmexecute.cpp +++ b/VM/src/lvmexecute.cpp @@ -77,6 +77,11 @@ if (LUAU_UNLIKELY(!!interrupt)) \ { /* the interrupt hook is called right before we advance pc */ \ VM_PROTECT(L->ci->savedpc++; interrupt(L, -1)); \ + if (L->status != 0) \ + { \ + L->ci->savedpc--; \ + goto exit; \ + } \ } \ } #endif diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 55b0618ad..17fd6b133 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2755,4 +2755,166 @@ local abc = b@1 CHECK(ac.entryMap["bar"].parens == ParenthesesRecommendation::CursorInside); } +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_on_class") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + loadDefinition(R"( +declare class Foo + function one(self): number + two: () -> number +end + )"); + + { + check(R"( +local t: Foo +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("one")); + REQUIRE(ac.entryMap.count("two")); + CHECK(!ac.entryMap["one"].wrongIndexType); + CHECK(ac.entryMap["two"].wrongIndexType); + } + + { + check(R"( +local t: Foo +t.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("one")); + REQUIRE(ac.entryMap.count("two")); + CHECK(ac.entryMap["one"].wrongIndexType); + CHECK(!ac.entryMap["two"].wrongIndexType); + } +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local t = {} +function t.m() end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + CHECK(ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_2") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local f: (() -> number) & ((number) -> number) = function(x: number?) return 2 end +local t = {} +t.f = f +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("f")); + CHECK(ac.entryMap["f"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "no_incompatible_self_calls_provisional") +{ + check(R"( +local t = {} +function t.m(x: typeof(t)) end +t:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("m")); + // We can make changes to mark this as a wrong way to call even though it's compatible + CHECK(!ac.entryMap["m"].wrongIndexType); +} + +TEST_CASE_FIXTURE(ACFixture, "string_prim_self_calls_are_fine") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local s = "hello" +s:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == false); +} + +TEST_CASE_FIXTURE(ACFixture, "string_prim_non_self_calls_are_avoided") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +local s = "hello" +s.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == true); +} + +TEST_CASE_FIXTURE(ACFixture, "string_library_non_self_calls_are_fine") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +string.@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == false); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == false); +} + +TEST_CASE_FIXTURE(ACFixture, "string_library_self_calls_are_invalid") +{ + ScopedFastFlag selfCallAutocompleteFix{"LuauSelfCallAutocompleteFix", true}; + + check(R"( +string:@1 + )"); + + auto ac = autocomplete('1'); + + REQUIRE(ac.entryMap.count("byte")); + CHECK(ac.entryMap["byte"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("char")); + CHECK(ac.entryMap["char"].wrongIndexType == true); + REQUIRE(ac.entryMap.count("sub")); + CHECK(ac.entryMap["sub"].wrongIndexType == true); +} + TEST_SUITE_END(); diff --git a/tests/Compiler.test.cpp b/tests/Compiler.test.cpp index f982c86fa..3dc57da0c 100644 --- a/tests/Compiler.test.cpp +++ b/tests/Compiler.test.cpp @@ -2819,8 +2819,6 @@ RETURN R1 -1 TEST_CASE("FastcallSelect") { - ScopedFastFlag sff("LuauCompileSelectBuiltin2", true); - // select(_, ...) compiles to a builtin call CHECK_EQ("\n" + compileFunction0("return (select('#', ...))"), R"( LOADK R1 K0 diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 63fbb363b..9e4cb4a59 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -569,6 +569,11 @@ TEST_CASE("Debugger") CHECK(lua_tointeger(L, -1) == 50); lua_pop(L, 1); + int v = lua_getargument(L, 0, 2); + REQUIRE(v); + CHECK(lua_tointeger(L, -1) == 42); + lua_pop(L, 1); + // test lua_getlocal const char* l = lua_getlocal(L, 0, 1); REQUIRE(l); @@ -652,31 +657,6 @@ TEST_CASE("SameHash") CHECK(luaS_hash(buf + 1, 120) == luaS_hash(buf + 2, 120)); } -TEST_CASE("InlineDtor") -{ - static int dtorhits = 0; - - dtorhits = 0; - - { - StateRef globalState(luaL_newstate(), lua_close); - lua_State* L = globalState.get(); - - void* u1 = lua_newuserdatadtor(L, 4, [](void* data) { - dtorhits += *(int*)data; - }); - - void* u2 = lua_newuserdatadtor(L, 1, [](void* data) { - dtorhits += *(char*)data; - }); - - *(int*)u1 = 39; - *(char*)u2 = 3; - } - - CHECK(dtorhits == 42); -} - TEST_CASE("Reference") { static int dtorhits = 0; @@ -969,7 +949,7 @@ TEST_CASE("StringConversion") TEST_CASE("GCDump") { // internal function, declared in lgc.h - not exposed via lua.h - extern void luaC_dump(lua_State* L, void* file, const char* (*categoryName)(lua_State* L, uint8_t memcat)); + extern void luaC_dump(lua_State * L, void* file, const char* (*categoryName)(lua_State * L, uint8_t memcat)); StateRef globalState(luaL_newstate(), lua_close); lua_State* L = globalState.get(); @@ -1015,4 +995,114 @@ TEST_CASE("GCDump") fclose(f); } +TEST_CASE("Interrupt") +{ + static const int expectedhits[] = { + 2, + 9, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 5, + 6, + 11, + }; + static int index; + + index = 0; + + runConformance( + "interrupt.lua", + [](lua_State* L) { + auto* cb = lua_callbacks(L); + + // note: for simplicity here we setup the interrupt callback once + // however, this carries a noticeable performance cost. in a real application, + // it's advised to set interrupt callback on a timer from a different thread, + // and set it back to nullptr once the interrupt triggered. + cb->interrupt = [](lua_State* L, int gc) { + if (gc >= 0) + return; + + CHECK(index < int(std::size(expectedhits))); + + lua_Debug ar = {}; + lua_getinfo(L, 0, "l", &ar); + + CHECK(ar.currentline == expectedhits[index]); + + index++; + + // check that we can yield inside an interrupt + if (index == 5) + lua_yield(L, 0); + }; + }, + [](lua_State* L) { + CHECK(index == 5); // a single yield point + }); + + CHECK(index == int(std::size(expectedhits))); +} + +TEST_CASE("UserdataApi") +{ + static int dtorhits = 0; + + dtorhits = 0; + + StateRef globalState(luaL_newstate(), lua_close); + lua_State* L = globalState.get(); + + // setup dtor for tag 42 (created later) + lua_setuserdatadtor(L, 42, [](void* data) { + dtorhits += *(int*)data; + }); + + // light user data + int lud; + lua_pushlightuserdata(L, &lud); + + CHECK(lua_touserdata(L, -1) == &lud); + CHECK(lua_topointer(L, -1) == &lud); + + // regular user data + int* ud1 = (int*)lua_newuserdata(L, 4); + *ud1 = 42; + + CHECK(lua_touserdata(L, -1) == ud1); + CHECK(lua_topointer(L, -1) == ud1); + + // tagged user data + int* ud2 = (int*)lua_newuserdatatagged(L, 4, 42); + *ud2 = -4; + + CHECK(lua_touserdatatagged(L, -1, 42) == ud2); + CHECK(lua_touserdatatagged(L, -1, 41) == nullptr); + CHECK(lua_userdatatag(L, -1) == 42); + + // user data with inline dtor + void* ud3 = lua_newuserdatadtor(L, 4, [](void* data) { + dtorhits += *(int*)data; + }); + + void* ud4 = lua_newuserdatadtor(L, 1, [](void* data) { + dtorhits += *(char*)data; + }); + + *(int*)ud3 = 43; + *(char*)ud4 = 3; + + globalState.reset(); + + CHECK(dtorhits == 42); +} + TEST_SUITE_END(); diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 4d6c207cc..91b23197c 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -1667,8 +1667,6 @@ _ = (math.random() < 0.5 and false) or 42 -- currently ignored TEST_CASE_FIXTURE(Fixture, "WrongComment") { - ScopedFastFlag sff("LuauParseAllHotComments", true); - LintResult result = lint(R"( --!strict --!struct diff --git a/tests/Module.test.cpp b/tests/Module.test.cpp index e3993cc53..82b7a3509 100644 --- a/tests/Module.test.cpp +++ b/tests/Module.test.cpp @@ -248,10 +248,12 @@ TEST_CASE_FIXTURE(Fixture, "clone_seal_free_tables") TEST_CASE_FIXTURE(Fixture, "clone_self_property") { + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + fileResolver.source["Module/A"] = R"( --!nonstrict local a = {} - function a:foo(x) + function a:foo(x: number) return -x; end return a; @@ -267,10 +269,10 @@ TEST_CASE_FIXTURE(Fixture, "clone_self_property") )"; result = frontend.check("Module/B"); - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), "This function was declared to accept self, but you did not pass enough arguments. Use a colon instead of a " - "dot or pass 1 extra nil to suppress this warning"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "clone_recursion_limit") diff --git a/tests/NonstrictMode.test.cpp b/tests/NonstrictMode.test.cpp index 5bad99014..d3faea2aa 100644 --- a/tests/NonstrictMode.test.cpp +++ b/tests/NonstrictMode.test.cpp @@ -126,6 +126,8 @@ TEST_CASE_FIXTURE(Fixture, "parameters_having_type_any_are_optional") TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") { + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + CheckResult result = check(R"( --!nonstrict local T = {} @@ -136,31 +138,25 @@ TEST_CASE_FIXTURE(Fixture, "local_tables_are_not_any") T:staticmethod() )"); - LUAU_REQUIRE_ERROR_COUNT(2, result); + LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return get(e); - })); - CHECK(std::any_of(result.errors.begin(), result.errors.end(), [](const TypeError& e) { - return get(e); - })); + CHECK_EQ("This function does not take self. Did you mean to use a dot instead of a colon?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "offer_a_hint_if_you_use_a_dot_instead_of_a_colon") { + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + CheckResult result = check(R"( --!nonstrict local T = {} - function T:method() end - T.method() + function T:method(x: number) end + T.method(5) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - auto e = get(result.errors[0]); - REQUIRE(e != nullptr); - - REQUIRE_EQ(1, e->requiredExtraNils); + CHECK_EQ("This function must be called with self. Did you mean to use a colon instead of a dot?", toString(result.errors[0])); } TEST_CASE_FIXTURE(Fixture, "table_props_are_any") diff --git a/tests/TypeInfer.aliases.test.cpp b/tests/TypeInfer.aliases.test.cpp index 711c0aa1d..b2e760528 100644 --- a/tests/TypeInfer.aliases.test.cpp +++ b/tests/TypeInfer.aliases.test.cpp @@ -676,4 +676,25 @@ TEST_CASE_FIXTURE(Fixture, "forward_declared_alias_is_not_clobbered_by_prior_uni LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") +{ + CheckResult result = check(R"( + type Tree = { data: T, children: {Tree} } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") +{ + ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; + + CheckResult result = check(R"( + -- this would be an infinite type if we allowed it + type Tree = { data: T, children: {Tree<{T}>} } + )"); + + LUAU_REQUIRE_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.anyerror.test.cpp b/tests/TypeInfer.anyerror.test.cpp new file mode 100644 index 000000000..5224b5d88 --- /dev/null +++ b/tests/TypeInfer.anyerror.test.cpp @@ -0,0 +1,335 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferAnyError"); + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(typeChecker.anyType, requireType("a")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") +{ + CheckResult result = check(R"( + function bar(): any + return true + end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") +{ + CheckResult result = check(R"( + local bar: any + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") +{ + CheckResult result = check(R"( + local a + for b in bar do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") +{ + CheckResult result = check(R"( + function bar(c) return c end + + local a + for b in bar() do + a = b + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local l = #this_is_not_defined + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local originalReward = unknown.Parent.Reward:GetChildren()[1] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "dot_on_error_type_does_not_produce_an_error") +{ + CheckResult result = check(R"( + local foo = (true).x + foo.x = foo.y + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "any_type_propagates") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo:method("argument") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "can_subscript_any") +{ + CheckResult result = check(R"( + local foo: any + local bar = foo[5] + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("any", toString(requireType("bar"))); +} + +// Not strictly correct: metatables permit overriding this +TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") +{ + CheckResult result = check(R"( + local foo: any = {} + local bar = #foo + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); +} + +TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") +{ + CheckResult result = check(R"( + local f: any + local T = {} + + T.prop = f() + + return T + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* ttv = getMutable(requireType("T")); + REQUIRE(ttv); + REQUIRE(ttv->props.count("prop")); + + REQUIRE_EQ("any", toString(ttv->props["prop"].type)); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") +{ + CheckResult result = check(R"( + local A : any + function A.B() end + A:C() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId aType = requireType("A"); + CHECK_EQ(aType, typeChecker.anyType); +} + +TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = unknown.Parent.Reward.GetChildren() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* err = get(result.errors[0]); + REQUIRE(err != nullptr); + + CHECK_EQ("unknown", err->name); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") +{ + CheckResult result = check(R"( + local a = Utility.Create "Foo" {} + )"); + + CHECK_EQ("*unknown*", toString(requireType("a"))); +} + +TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") +{ + CheckResult result = check(R"( + local a: any + local b + for _, i in pairs(a) do + b = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") +{ + CheckResult result = check(R"( + local a: any + local b = a() + )"); + + REQUIRE_EQ("any", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfAny") +{ + CheckResult result = check(R"( +local x: any = {} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") +{ + CheckResult result = check(R"( +local x = (true).foo +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") +{ + CheckResult result = check(R"( +--!strict +local T: any +T = {} +T.__index = T +function T.new(...) + local self = {} + setmetatable(self, T) + self:construct(...) + return self +end +function T:construct(index) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "type_error_addition") +{ + CheckResult result = check(R"( +--!strict +local foo = makesandwich() +local bar = foo.nutrition + 100 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + // We should definitely get this error + CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); + // We get this error if makesandwich() returns a free type + // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); +} + +TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") +{ + CheckResult result = check(R"( + local function f(thing: any | string) + local foo = thing.SomeRandomKey + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index 8da655b34..ec20a2c7f 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -871,6 +871,7 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") ScopedFastFlag sff[]{ {"LuauAssertStripsFalsyTypes", true}, {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -879,6 +880,26 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") end )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("((boolean | number)?) -> boolean | number", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauAssertStripsFalsyTypes", true}, + {"LuauDiscriminableUnions2", true}, + {"LuauWidenIfSupertypeIsFree2", true}, + }; + + CheckResult result = check(R"( + local function f(x: (number | boolean)?): number | true + return assert(x) + end + )"); + LUAU_REQUIRE_NO_ERRORS(result); CHECK_EQ("((boolean | number)?) -> number | true", toString(requireType("f"))); } @@ -958,4 +979,43 @@ a:b({}) CHECK_EQ(result.errors[1], (TypeError{Location{{3, 0}, {3, 5}}, CountMismatch{2, 1}})); } +TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") +{ + CheckResult result = check(R"( +local function f(a: typeof(f)) end +)"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") +{ + TypeId mathTy = requireType(typeChecker.globalScope, "math"); + REQUIRE(mathTy); + TableTypeVar* ttv = getMutable(mathTy); + REQUIRE(ttv); + const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); + REQUIRE(ftv); + auto original = ftv->level; + + CheckResult result = check("local a = math.frexp"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK(ftv->level.level == original.level); + CHECK(ftv->level.subLevel == original.subLevel); +} + +TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") +{ + CheckResult result = check(R"( +local function f(x: string) + local p = x:split('a') + p = table.pack(table.unpack(p, 1, #p - 1)) + return p +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.definitions.test.cpp b/tests/TypeInfer.definitions.test.cpp index c6d55793d..898d89029 100644 --- a/tests/TypeInfer.definitions.test.cpp +++ b/tests/TypeInfer.definitions.test.cpp @@ -293,4 +293,22 @@ TEST_CASE_FIXTURE(Fixture, "documentation_symbols_dont_attach_to_persistent_type CHECK_EQ(ty->type->documentationSymbol, std::nullopt); } +TEST_CASE_FIXTURE(Fixture, "single_class_type_identity_in_global_types") +{ + ScopedFastFlag luauCloneDeclaredGlobals{"LuauCloneDeclaredGlobals", true}; + + loadDefinition(R"( +declare class Cls +end + +declare GetCls: () -> (Cls) + )"); + + CheckResult result = check(R"( +local s : Cls = GetCls() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp new file mode 100644 index 000000000..4288098a3 --- /dev/null +++ b/tests/TypeInfer.functions.test.cpp @@ -0,0 +1,1338 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferFunctions"); + +TEST_CASE_FIXTURE(Fixture, "tc_function") +{ + CheckResult result = check("function five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* fiveType = get(requireType("five")); + REQUIRE(fiveType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_bodies") +{ + CheckResult result = check("function myFunction() local a = 0 a = true end"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.booleanType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type") +{ + CheckResult result = check("function take_five() return 5 end"); + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* takeFiveType = get(requireType("take_five")); + REQUIRE(takeFiveType != nullptr); + + std::vector retVec = flatten(takeFiveType->retType).first; + REQUIRE(!retVec.empty()); + + REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") +{ + CheckResult result = check("function take_five() return 5 end local five = take_five()"); + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") +{ + CheckResult result = check(R"( + function take_five() + return 5 + end + + take_five().prop = 888 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") +{ + CheckResult result = check(R"( + function f(...) end + + f(1) + f("foo", 2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") +{ + CheckResult result = check(R"( + local T = {} + function T.f(...) + local result = {} + + for i = 1, select("#", ...) do + local dictionary = select(i, ...) + for key, value in pairs(dictionary) do + result[key] = value + end + end + + return result + end + + return T + )"); + + auto r = first(getMainModule()->getModuleScope()->returnType); + REQUIRE(r); + + TableTypeVar* ttv = getMutable(*r); + REQUIRE(ttv); + + TypeId k = ttv->props["f"].type; + REQUIRE(k); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply("") + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply() + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("No overload for function accepts 0 arguments.", ge->message); + + ExtraInformation* ei = get(result.errors[1]); + REQUIRE(ei); + CHECK_EQ("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number", ei->message); +} + +TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists") +{ + CheckResult result = check(R"( + local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) + multiply(1, "") + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> string)} + local T: T + + local a = T.method(T, 4) + local b = T.method(5) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("string", toString(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_arguments") +{ + CheckResult result = check(R"( + --!nonstrict + + function g(a: number) end + + g() + + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(1, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_function") +{ + CheckResult result = check(R"( + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "lambda_form_of_local_function_cannot_be_recursive") +{ + CheckResult result = check(R"( + local f = function() return f() end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "recursive_local_function") +{ + CheckResult result = check(R"( + local function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +// FIXME: This and the above case get handled very differently. It's pretty dumb. +// We really should unify the two code paths, probably by deleting AstStatFunction. +TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") +{ + CheckResult result = check(R"( + local count + function count(n: number) + if n == 0 then + return 0 + else + return count(n - 1) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") +{ + CheckResult result = check(R"( + function f() + return f + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") +{ + CheckResult result = check(R"( + function f(g) + return f(f) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") +{ + CheckResult result = check(R"( + local Get_des + function Get_des(func) + Get_des(func) + end + + local function f(d) + d:IsA("BasePart") + d.Parent:FindFirstChild("Humanoid") + d:IsA("Decal") + end + Get_des(f) + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") +{ + CheckResult result = check(R"( + local d + d:foo() + d:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "local_function") +{ + CheckResult result = check(R"( + function f() + return 8 + end + + function g() + local function f() + return 'hello' + end + return f + end + + local h = g() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId h = follow(requireType("h")); + + const FunctionTypeVar* ftv = get(h); + REQUIRE(ftv != nullptr); + + std::optional rt = first(ftv->retType); + REQUIRE(bool(rt)); + + TypeId retType = follow(*rt); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); +} + +TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") +{ + CheckResult result = check(R"( + local p = function(x) return x end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + const Luau::FunctionTypeVar* fn = get(requireType("p")); + REQUIRE(fn); + auto ret = first(fn->retType); + REQUIRE(ret); + REQUIRE(get(follow(*ret))); +} + +TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") +{ + CheckResult result = check(R"( + local T = {} + function T.new(a: number?, b: number?, c: number?) return 5 end + local m = T.new() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_not_to_supply_enough_retvals") +{ + CheckResult result = check(R"( + function get_two() return 5, 6 end + + local a = get_two() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions2") +{ + CheckResult result = check(R"( + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo() end + + function foo() end + + function bar() + local function foo() end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + function foo(): number + return 1 + end + foo() + + function foo(n: number): number + return 2 + end + foo() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> number", toString(tm->wantedType)); + CHECK_EQ("(number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotation") +{ + CheckResult result = check(R"( + local i = 0 + function most_of_the_natural_numbers(): number? + if i < 10 then + i = i + 1 + return i + else + return nil + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); + + std::optional retType = first(functionType->retType); + REQUIRE(retType); + CHECK(get(*retType)); +} + +TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") +{ + CheckResult result = check(R"( + function apply(f, x) + return f(x) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("apply")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const FunctionTypeVar* fType = get(follow(argVec[0])); + REQUIRE(fType != nullptr); + + std::vector fArgs = flatten(fType->argTypes).first; + + TypeId xType = argVec[1]; + + CHECK_EQ(1, fArgs.size()); + CHECK_EQ(xType, fArgs[0]); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(6, argVec.size()); + + const FunctionTypeVar* fType = get(follow(argVec[0])); + REQUIRE(fType != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") +{ + CheckResult result = check(R"( + function swap(p) + local t = p[0] + p[0] = p[1] + p[1] = t + return nil + end + + function swapTwice(p) + swap(p) + swap(p) + return p + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* ftv = get(requireType("swapTwice")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(1, argVec.size()); + + const TableTypeVar* argType = get(follow(argVec[0])); + REQUIRE(argType != nullptr); + + CHECK(bool(argType->indexer)); +} + +TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") +{ + CheckResult result = check(R"( + function bottomupmerge(comp, a, b, left, mid, right) + local i, j = left, mid + for k = left, right do + if i < mid and (j > right or not comp(a[j], a[i])) then + b[k] = a[i] + i = i + 1 + else + b[k] = a[j] + j = j + 1 + end + end + end + + function mergesort(arr, comp) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + local width = 1 + while width < #arr do + for i = 1, #arr, 2*width do + bottomupmerge(comp, arr, work, i, math.min(i+width, #arr), math.min(i+2*width-1, #arr)) + end + local temp = work + work = arr + arr = temp + width = width * 2 + end + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + /* + * mergesort takes two arguments: an array of some type T and a function that takes two Ts. + * We must assert that these two types are in fact the same type. + * In other words, comp(arr[x], arr[y]) is well-typed. + */ + + const FunctionTypeVar* ftv = get(requireType("mergesort")); + REQUIRE(ftv != nullptr); + + std::vector argVec = flatten(ftv->argTypes).first; + + REQUIRE_EQ(2, argVec.size()); + + const TableTypeVar* arg0 = get(follow(argVec[0])); + REQUIRE(arg0 != nullptr); + REQUIRE(bool(arg0->indexer)); + + const FunctionTypeVar* arg1 = get(follow(argVec[1])); + REQUIRE(arg1 != nullptr); + REQUIRE_EQ(2, size(arg1->argTypes)); + + std::vector arg1Args = flatten(arg1->argTypes).first; + + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[0]); + CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); +} + +TEST_CASE_FIXTURE(Fixture, "mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + + function newPlayerCharacter() + startGui() -- Unknown symbol 'startGui' + end + + local characterAddedConnection: any + function startGui() + characterAddedConnection = game:GetService("Players").LocalPlayer.CharacterAdded:connect(newPlayerCharacter) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") +{ + CheckResult result = check(R"( + --!strict + local x = nil + function f() g() end + -- make sure print(x) doesn't get toposorted here, breaking the mutual block + function g() x = f end + print(x) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + +TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") +{ + CheckResult result = check(R"( + --!nonstrict + + function f() + return 114 + end + + return function() + return f():andThen() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") +{ + CheckResult result = check(R"( + function onerror() end + function foo() end + xpcall(foo, onerror) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments") +{ + CheckResult result = check(R"( + local mycb: (number, number) -> () + + function f() end + + mycb = f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") +{ + CheckResult result = check(R"( + --!nonstrict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + FunctionExitsWithoutReturning* err = get(result.errors[0]); + CHECK(err); +} + +TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") +{ + CheckResult result = check(R"( + --!strict + + local function f1(v): number? + if v then + return 1 + end + end + + local function f2(v) + if v then + return 1 + end + end + + local function f3(v): () + if v then + return + end + end + + local function f4(v) + if v then + return + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + FunctionExitsWithoutReturning* annotatedErr = get(result.errors[0]); + CHECK(annotatedErr); + + FunctionExitsWithoutReturning* inferredErr = get(result.errors[1]); + CHECK(inferredErr); +} + +TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields_errors_spanning_argument") +{ + CheckResult result = check(R"( + function foo(a: number, b: string) end + + foo("Test", 123) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); + + CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ + typeChecker.stringType, + typeChecker.numberType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") +{ + CheckResult result = check(R"( + --!nonstrict + + function Test(a) + return 1, "" + end + + + local tab = {} + table.insert(tab, Test(1)); + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions opts; + opts.exhaustive = true; + opts.maxTableLength = 0; + + CHECK_EQ("{any}", toString(requireType("tab"), opts)); +} + +TEST_CASE_FIXTURE(Fixture, "too_many_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55 + end + + local a, b = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Result); + CHECK_EQ(acm->expected, 1); + CHECK_EQ(acm->actual, 2); +} + +TEST_CASE_FIXTURE(Fixture, "ignored_return_values") +{ + CheckResult result = check(R"( + --!strict + + function f() + return 55, "" + end + + local a = f() + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") +{ + CheckResult result = check(R"( + --!strict + + function f(): (number, string) + return 55 + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Return); + CHECK_EQ(acm->expected, 2); + CHECK_EQ(acm->actual, 1); +} + +TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") +{ + CheckResult result = check(R"( + function foo(a, b): number + return 0 + end + + local a: (string)->number = foo + local b: (number, number)->(number, number) = foo + + local c: (string, number)->number = foo -- no error + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + auto tm1 = get(result.errors[0]); + REQUIRE(tm1); + + CHECK_EQ("(string) -> number", toString(tm1->wantedType)); + CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); + + auto tm2 = get(result.errors[1]); + REQUIRE(tm2); + + CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); + CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") +{ + CheckResult result = check(R"( + --!strict + local tbl = {} + function tbl:abc(a: number, b: number) + return a + end + tbl:abc(1, 2) -- Line 6 + -- | Column 14 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + TypeId type = requireTypeAtPosition(Position(6, 14)); + CHECK_EQ("(tbl, number, number) -> number", toString(type)); + auto ftv = get(type); + REQUIRE(ftv); + CHECK(ftv->hasSelf); +} + +TEST_CASE_FIXTURE(Fixture, "record_matching_overload") +{ + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number) -> number) + local abc: Overload + abc(1) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // AstExprCall is the node that has the overload stored on it. + // findTypeAtPosition will look at the AstExprLocal, but this is not what + // we want to look at. + std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(3, 10)); + REQUIRE_GE(ancestry.size(), 2); + AstExpr* parentExpr = ancestry[ancestry.size() - 2]->asExpr(); + REQUIRE(bool(parentExpr)); + REQUIRE(parentExpr->is()); + + ModulePtr module = getMainModule(); + auto it = module->astOverloadResolvedTypes.find(parentExpr); + REQUIRE(it); + CHECK_EQ(toString(*it), "(number) -> number"); +} + +TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + type Overload = ((string) -> string) & ((number, number) -> number) + local abc: Overload + local x = abc(true) + local y = abc(true,true) + local z = abc(true,true,true) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("string", toString(requireType("x"))); + CHECK_EQ("number", toString(requireType("y"))); + // Should this be string|number? + CHECK_EQ("string", toString(requireType("z"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +{ + // Simple direct arg to arg propagation + CheckResult result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // An optional function is accepted, but since we already provide a function, nil can be ignored + result = check(R"( +type Table = { x: number, y: number } +local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Make sure self calls match correct index + result = check(R"( +type Table = { x: number, y: number } +local x = {} +x.b = {x = 1, y = 2} +function x:f(a: (Table) -> number) return a(self.b) end +x:f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Mix inferred and explicit argument types + result = check(R"( +function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end +f(function(a: number, b, c) return c and a + b or b - a end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Anonymous function has a variadic pack + result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(...) return select(1, ...).z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Can't accept more arguments than provided + result = check(R"( +function f(a: (a: number, b: number) -> number) return a(1, 2) end +f(function(a, b, c, ...) return a + b end) + )"); + + LUAU_REQUIRE_ERRORS(result); + + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + + // Infer from variadic packs into elements + result = check(R"( +function f(a: (...number) -> number) return a(1, 2) end +f(function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Infer from variadic packs into variadic packs + result = check(R"( +type Table = { x: number, y: number } +function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end +f(function(a, ...) local b = ... return b.z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Return type inference + result = check(R"( +type Table = { x: number, y: number } +function f(a: (number) -> Table) return a(4) end +f(function(x) return x * 2 end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + + // Return type doesn't inference 'nil' + result = check(R"( +function f(a: (number) -> nil) return a(4) end +f(function(x) print(x) end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") +{ + // Simple direct arg to arg propagation + CheckResult result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // An optional function is accepted, but since we already provide a function, nil can be ignored + result = check(R"( +type Table = { x: number, y: number } +local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end +f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Make sure self calls match correct index + result = check(R"( +type Table = { x: number, y: number } +local x = {} +x.b = {x = 1, y = 2} +function x:f(a: (Table) -> number) return a(self.b) end +x:f(function(a) return a.x + a.y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Mix inferred and explicit argument types + result = check(R"( +function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end +f(function(a: number, b, c) return c and a + b or b - a end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Anonymous function has a variadic pack + result = check(R"( +type Table = { x: number, y: number } +local function f(a: (Table) -> number) return a({x = 1, y = 2}) end +f(function(...) return select(1, ...).z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Can't accept more arguments than provided + result = check(R"( +function f(a: (a: number, b: number) -> number) return a(1, 2) end +f(function(a, b, c, ...) return a + b end) + )"); + + LUAU_REQUIRE_ERRORS(result); + + CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' +caused by: + Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", + toString(result.errors[0])); + + // Infer from variadic packs into elements + result = check(R"( +function f(a: (...number) -> number) return a(1, 2) end +f(function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // Infer from variadic packs into variadic packs + result = check(R"( +type Table = { x: number, y: number } +function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end +f(function(a, ...) local b = ... return b.z end) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); + + // Return type inference + result = check(R"( +type Table = { x: number, y: number } +function f(a: (number) -> Table) return a(4) end +f(function(x) return x * 2 end) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); + + // Return type doesn't inference 'nil' + result = check(R"( +function f(a: (number) -> nil) return a(4) end +f(function(x) print(x) end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") +{ + CheckResult result = check(R"( +type Table = { x: number, y: number } +local f: (Table) -> number = function(t) return t.x + t.y end + +type TableWithFunc = { x: number, y: number, f: (number, number) -> number } +local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") +{ + CheckResult result = check(R"( +local function f(): {string|number} + return {1, "b", 3} +end + +local function g(): (number, {string|number}) + return 4, {1, "b", 3} +end + +local function h(): ...{string|number} + return {4}, {1, "b", 3}, {"s"} +end + +local function i(): ...{string|number} + return {1, "b", 3}, h() +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' +caused by: + Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, string) -> string + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' +caused by: + Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") +{ + CheckResult result = check(R"( +type A = (number, number) -> (number) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Function only returns 1 value. 2 are required here)"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") +{ + CheckResult result = check(R"( +type A = (number, number) -> string +type B = (number, number) -> number + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' +caused by: + Return type is not compatible. Type 'string' could not be converted into 'number')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") +{ + CheckResult result = check(R"( +type A = (number, number) -> (number, string) +type B = (number, number) -> (number, boolean) + +local a: A +local b: B = a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), + R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' +caused by: + Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + + fileResolver.source["game/isAMagicMock"] = R"( +--!nonstrict +return function(value) + return false +end + )"; + + CheckResult result = check(R"( +--!nonstrict +local MagicMock = {} +MagicMock.is = require(game.isAMagicMock) + +function MagicMock.is(value) + return false +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + + CheckResult result = check(R"( +function string.len(): number + return 1 +end + )"); + + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "strict_mode_ok_with_missing_arguments") +{ + ScopedFastFlag sff{"LuauAnyInIsOptionalIsOptional", true}; + + CheckResult result = check(R"( + local function f(x: any) end + f() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "function_statement_sealed_table_assignment_through_indexer") +{ + ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify2", true}; + + CheckResult result = check(R"( +local t: {[string]: () -> number} = {} + +function t.a() return 1 end -- OK +function t:b() return 2 end -- not OK + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(R"(Type '(*unknown*) -> number' could not be converted into '() -> number' +caused by: + Argument count mismatch. Function expects 1 argument, but none are specified)", + toString(result.errors[0])); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.generics.test.cpp b/tests/TypeInfer.generics.test.cpp index 547fbab15..f360a77cc 100644 --- a/tests/TypeInfer.generics.test.cpp +++ b/tests/TypeInfer.generics.test.cpp @@ -1,6 +1,9 @@ // This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details #include "Luau/TypeInfer.h" #include "Luau/TypeVar.h" +#include "Luau/Scope.h" + +#include #include "Fixture.h" @@ -830,5 +833,303 @@ wrapper(test2, 1, "", 3) CHECK_EQ(toString(result.errors[0]), R"(Argument count mismatch. Function expects 1 argument, but 4 are specified)"); } +TEST_CASE_FIXTURE(Fixture, "generic_function") +{ + CheckResult result = check(R"( + function id(x) return x end + local a = id(55) + local b = id(nil) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("a")); + CHECK_EQ(*typeChecker.nilType, *requireType("b")); +} + +TEST_CASE_FIXTURE(Fixture, "generic_table_method") +{ + CheckResult result = check(R"( + local T = {} + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TypeId tType = requireType("T"); + TableTypeVar* tTable = getMutable(tType); + REQUIRE(tTable != nullptr); + + TypeId barType = tTable->props["bar"].type; + REQUIRE(barType != nullptr); + + const FunctionTypeVar* ftv = get(follow(barType)); + REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); + + std::vector args = flatten(ftv->argTypes).first; + TypeId argType = args.at(1); + + CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(5) + end + + function T:bar(i) + return i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const TableTypeVar* t = get(requireType("T")); + REQUIRE(t != nullptr); + + std::optional fooProp = get(t->props, "foo"); + REQUIRE(bool(fooProp)); + + const FunctionTypeVar* foo = get(follow(fooProp->type)); + REQUIRE(bool(foo)); + + std::optional ret_ = first(foo->retType); + REQUIRE(bool(ret_)); + TypeId ret = follow(*ret_); + + REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); +} + +/* + * We had a bug in instantiation where the argument types of 'f' and 'g' would be inferred as + * f {+ method: function(): (t2, T3...) +} + * g {+ method: function({+ method: function(): (t2, T3...) +}): (t5, T6...) +} + * + * The type of 'g' is totally wrong as t2 and t5 should be unified, as should T3 with T6. + * + * The correct unification of the argument to 'g' is + * + * {+ method: function(): (t5, T6...) +} + */ +TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") +{ + auto result = check(R"( + function f(o) + o:method() + end + + function g(o) + f(o) + end + )"); + + TypeId g = requireType("g"); + const FunctionTypeVar* gFun = get(g); + REQUIRE(gFun != nullptr); + + auto optionArg = first(gFun->argTypes); + REQUIRE(bool(optionArg)); + + TypeId arg = follow(*optionArg); + const TableTypeVar* argTable = get(arg); + REQUIRE(argTable != nullptr); + + std::optional methodProp = get(argTable->props, "method"); + REQUIRE(bool(methodProp)); + + const FunctionTypeVar* methodFunction = get(methodProp->type); + REQUIRE(methodFunction != nullptr); + + std::optional methodArg = first(methodFunction->argTypes); + REQUIRE(bool(methodArg)); + + REQUIRE_EQ(follow(*methodArg), follow(arg)); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local c: ((number)->number, number)->number = foo -- no error + c = foo -- no error + local d: ((number)->number, string)->number = foo -- error from arg 2 (string) not being convertable to number from the call a(b) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); + CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") +{ + CheckResult result = check(R"( + function foo(a, b) + return a(b) + end + + function bar() + local _: (string, string)->number = foo -- string cannot be converted to (string)->number + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); + CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + // Mutability in type function application right now can create strange recursive types + CheckResult result = check(R"( +type Table = { a: number } +type Self = T +local a: Self
+ )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(requireType("a")), "Table"); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") +{ + CheckResult result = check(R"( + function _(l0:t0): (any, ()->()) + end + + type t0 = t0 | {} + )"); + + CHECK_LE(0, result.errors.size()); + + std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); + REQUIRE(t0); + CHECK_EQ("*unknown*", toString(t0->type)); + + auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { + return get(err); + }); + CHECK(it != result.errors.end()); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") +{ + ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; + + CheckResult result = check(R"( +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end +return sum(2, 3, function(a, b) return a + b end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( +local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end +local a = {1, 2, 3} +local r = map(a, function(a) return a + a > 100 end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{boolean}", toString(requireType("r"))); + + check(R"( +local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end +local a = {1, 2, 3} +local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") +{ + CheckResult result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12(1, function(x) return x + x end) +g12(1, 2, function(x, y) return x + y end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + result = check(R"( +local function g1(a: T, f: (T) -> T) return f(a) end +local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end + +local g12: typeof(g1) & typeof(g2) + +g12({x=1}, function(x) return {x=-x.x} end) +g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") +{ + CheckResult result = check(R"( +local a = {{x=4}, {x=7}, {x=1}} +table.sort(a, function(x, y) return x.x < y.x end) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") +{ + CheckResult result = check(R"( +local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end + +local function sumrec(f: typeof(sum)) + return sum(2, 3, function(a, b) return a + b end) +end + +local b = sumrec(sum) -- ok +local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " + "parameters", + toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") +{ + CheckResult result = check(R"( +type A = { x: number } +local a: A = { x = 1 } +local b = a +type B = typeof(b) +type X = T +local c: X + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} TEST_SUITE_END(); diff --git a/tests/TypeInfer.intersectionTypes.test.cpp b/tests/TypeInfer.intersectionTypes.test.cpp index 26881b5cc..d146f4e81 100644 --- a/tests/TypeInfer.intersectionTypes.test.cpp +++ b/tests/TypeInfer.intersectionTypes.test.cpp @@ -377,4 +377,32 @@ local b: number = a CHECK_EQ(toString(result.errors[0]), R"(Type 'X & Y & Z' could not be converted into 'number'; none of the intersection parts are compatible)"); } +TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") +{ + check(R"( +--!nonstrict +function _(...):((typeof(not _))&(typeof(not _)))&((typeof(not _))&(typeof(not _))) +_(...)(setfenv,_,not _,"")[_] = nil +end +do end +_(...)(...,setfenv,_):_G() +)"); +} + +TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") +{ + CheckResult result = check(R"( + local l0,l0 + repeat + type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) + function _(l0):(t0)&(t0) + while nil do + end + end + until _(_)(_)._ + )"); + + CHECK_LE(0, result.errors.size()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.loops.test.cpp b/tests/TypeInfer.loops.test.cpp new file mode 100644 index 000000000..30df717b1 --- /dev/null +++ b/tests/TypeInfer.loops.test.cpp @@ -0,0 +1,473 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferLoops"); + +TEST_CASE_FIXTURE(Fixture, "for_loop") +{ + CheckResult result = check(R"( + local q + for i=0, 50, 2 do + q = i + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("q")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop") +{ + CheckResult result = check(R"( + local n + local s + for i, v in pairs({ "foo" }) do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") +{ + CheckResult result = check(R"( + local n + local s + for i, v in next, { "foo", "bar" } do + n = i + s = v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("n")); + CHECK_EQ(*typeChecker.stringType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") +{ + CheckResult result = check(R"( + local it: any + local a, b + for i, v in it do + a, b = i, v + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") +{ + CheckResult result = check(R"( + local foo = "bar" + for i, v in foo do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") +{ + CheckResult result = check(R"( + local function keys(dictionary) + local new = {} + local index = 1 + + for key in pairs(dictionary) do + new[index] = key + index = index + 1 + end + + return new + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") +{ + CheckResult result = check(R"( + local function range(l, h): () -> number + return function() + return l + end + end + + for n: string in range(1, 10) do + print(n) + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") +{ + CheckResult result = check(R"( + function f(x) + gobble.prop = x.otherprop + end + + local p + for _, part in i_am_not_defined do + p = part + f(part) + part.thirdprop = false + end + )"); + + CHECK_EQ(2, result.errors.size()); + + TypeId p = requireType("p"); + CHECK_EQ("*unknown*", toString(p)); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") +{ + CheckResult result = check(R"( + local bad_iter = 5 + + for a in bad_iter() do + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +{ + CheckResult result = check(R"( + local function hasDivisors(value: number, table) + return false + end + + function prime_iter(state, index) + while hasDivisors(index, state) do + index += 1 + end + + state[index] = true + return index + end + + function primes1() + return prime_iter, {} + end + + function primes2() + return prime_iter, {}, "" + end + + function primes3() + return prime_iter, {}, 2 + end + + for p in primes1() do print(p) end -- mismatch in argument count + + for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string + + for p in primes3() do print(p) end -- no error + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + + TypeMismatch* tm = get(result.errors[1]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") +{ + CheckResult result = check(R"( + function prime_iter(state, index) + return 1 + end + + for p in prime_iter do print(p) end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CountMismatch* acm = get(result.errors[0]); + REQUIRE(acm); + CHECK_EQ(acm->context, CountMismatch::Arg); + CHECK_EQ(2, acm->expected); + CHECK_EQ(0, acm->actual); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") +{ + CheckResult result = check(R"( + function primes() + return function (state: number) end, 2 + end + + for p, q in primes do + q = "" + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ(typeChecker.stringType, tm->givenType); +} + +TEST_CASE_FIXTURE(Fixture, "while_loop") +{ + CheckResult result = check(R"( + local i + while true do + i = 8 + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.numberType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop") +{ + CheckResult result = check(R"( + local i + repeat + i = 'hi' + until true + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(*typeChecker.stringType, *requireType("i")); +} + +TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") +{ + CheckResult result = check(R"( + repeat + local x = true + until x + + print(x) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") +{ + CheckResult result = check(R"( + local T = {} + + function T.f(p) + for i, v in pairs(p) do + T.f(v) + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") +{ + // In this case, we cannot know the element type of the table {}. It could be anything. + // We therefore must initially ascribe a free typevar to iter. + CheckResult result = check(R"( + for iter in pairs({}) do + iter:g().p = true + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") +{ + CheckResult result = check(R"( + while true do + local a = 1 + end + + print(a) -- oops! + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ(us->name, "a"); +} + +TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") +{ + CheckResult result = check(R"( + local key + for i, e in ipairs({}) do key = i end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE_EQ("number", toString(requireType("key"))); +} + +TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") +{ + // This code doesn't pass typechecking. We just care that it doesn't crash. + (void)check(R"( + --!nonstrict + function _:_(...) + end + + repeat + if _ then + else + _ = ... + end + until _ + + for _ in _() do + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") +{ + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + while true do + if a then return 10 end + end + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a): number + while true do + if a then break end + return 10 + end + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a): number + repeat + if a then return 10 end + until false + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } + + { + CheckResult result = check(R"( + function reachablecodepath(a, b): number + repeat + if a then break end + + if b then return 10 end + until false + + print("x") -- correct error + end + reachablecodepath(4) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK(get(result.errors[0])); + } + + { + CheckResult result = check(R"( + function unreachablecodepath(a: number?): number + repeat + return 10 + until a ~= nil + + -- unreachable + end + unreachablecodepath(4) + )"); + + LUAU_REQUIRE_ERROR_COUNT(0, result); + } +} + +TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") +{ + CheckResult result = check(R"( + local t = {} + for _ in t do + for _ in assert(missing()) do + end + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp new file mode 100644 index 000000000..636436103 --- /dev/null +++ b/tests/TypeInfer.modules.test.cpp @@ -0,0 +1,310 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferModules"); + +TEST_CASE_FIXTURE(Fixture, "require") +{ + fileResolver.source["game/A"] = R"( + local function hooty(x: number): string + return "Hi there!" + end + + return {hooty=hooty} + )"; + + fileResolver.source["game/B"] = R"( + local Hooty = require(game.A) + + local h -- free! + local i = Hooty.hooty(h) + )"; + + CheckResult aResult = frontend.check("game/A"); + dumpErrors(aResult); + LUAU_REQUIRE_NO_ERRORS(aResult); + + CheckResult bResult = frontend.check("game/B"); + dumpErrors(bResult); + LUAU_REQUIRE_NO_ERRORS(bResult); + + ModulePtr b = frontend.moduleResolver.modules["game/B"]; + + REQUIRE(b != nullptr); + + dumpErrors(bResult); + + std::optional iType = requireType(b, "i"); + REQUIRE_EQ("string", toString(*iType)); + + std::optional hType = requireType(b, "h"); + REQUIRE_EQ("number", toString(*hType)); +} + +TEST_CASE_FIXTURE(Fixture, "require_types") +{ + fileResolver.source["workspace/A"] = R"( + export type Point = {x: number, y: number} + + return {} + )"; + + fileResolver.source["workspace/B"] = R"( + local Hooty = require(workspace.A) + + local h: Hooty.Point + )"; + + CheckResult bResult = frontend.check("workspace/B"); + dumpErrors(bResult); + + ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; + REQUIRE(b != nullptr); + + TypeId hType = requireType(b, "h"); + REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); +} + +TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") +{ + fileResolver.source["game/A"] = R"( + local T = {} + function T.f(...) end + return T + )"; + + fileResolver.source["game/B"] = R"( + local A = require(game.A) + local f = A.f + )"; + + CheckResult result = frontend.check("game/B"); + + ModulePtr bModule = frontend.moduleResolver.getModule("game/B"); + REQUIRE(bModule != nullptr); + + TypeId f = follow(requireType(bModule, "f")); + + const FunctionTypeVar* ftv = get(f); + REQUIRE(ftv); + + auto iter = begin(ftv->argTypes); + auto endIter = end(ftv->argTypes); + + REQUIRE(iter == endIter); + REQUIRE(iter.tail()); + + CHECK(get(*iter.tail())); +} + +TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") +{ + CheckResult result = check(R"( + local p: SomeModule.DoesNotExist + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); +} + +TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") +{ + const std::string sourceA = R"( + )"; + + const std::string sourceB = R"( + local Hooty = require(script.Parent.A) + )"; + + fileResolver.source["game/Workspace/A"] = sourceA; + fileResolver.source["game/Workspace/B"] = sourceB; + + frontend.check("game/Workspace/A"); + frontend.check("game/Workspace/B"); + + ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; + ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; + + CHECK(aModule->errors.empty()); + REQUIRE_EQ(1, bModule->errors.size()); + CHECK_MESSAGE(get(bModule->errors[0]), "Should be IllegalRequire: " << toString(bModule->errors[0])); + + auto hootyType = requireType(bModule, "Hooty"); + + CHECK_EQ("*unknown*", toString(hootyType)); +} + +TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") +{ + fileResolver.source["Modules/A"] = ""; + fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; + + fileResolver.source["Modules/B"] = R"( + local M = require(script.Parent.A) + )"; + + CheckResult result = frontend.check("Modules/B"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") +{ + fileResolver.source["game/A"] = R"( +--!strict +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +--!strict +local tbl = { abc = require(game.A) } +local a : string = "" +a = tbl.abc.def + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") +{ + fileResolver.source["game/A"] = R"( +return { def = 4 } + )"; + + fileResolver.source["game/B"] = R"( +local tbl: string = require(game.A) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") +{ + CheckResult result = check(R"( +local n = {} +function n:Clone() end + +local m = {} + +function m.a(x) + x:Clone() +end + +function m.b() + m.a(n) +end + +return m +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "custom_require_global") +{ + CheckResult result = check(R"( +--!nonstrict +require = function(a) end + +local crash = require(game.A) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "require_failed_module") +{ + fileResolver.source["game/A"] = R"( +return unfortunately() + )"; + + CheckResult aResult = frontend.check("game/A"); + LUAU_REQUIRE_ERRORS(aResult); + + CheckResult result = check(R"( +local ModuleA = require(game.A) + )"); + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional oty = requireType("ModuleA"); + CHECK_EQ("*unknown*", toString(*oty)); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") +{ + fileResolver.source["game/A"] = R"( +export type Type = { unrelated: boolean } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = {} +function x:Destroy(): () end + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +export type Type = { x: { a: number } } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = { x = { a = 2 } } +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") +{ + ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; + + fileResolver.source["game/A"] = R"( +local y = setmetatable({}, {}) +export type Type = { x: typeof(y) } +return { x = y } + )"; + + fileResolver.source["game/B"] = R"( +local types = require(game.A) +type Type = types.Type +local x: Type = types +type Rename = typeof(x.x) + )"; + + CheckResult result = frontend.check("game/B"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.oop.test.cpp b/tests/TypeInfer.oop.test.cpp new file mode 100644 index 000000000..40831bf61 --- /dev/null +++ b/tests/TypeInfer.oop.test.cpp @@ -0,0 +1,275 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferOOP"); + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function1 = function(Arg1) + end + + someTable.Function1() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") +{ + CheckResult result = check(R"( + local someTable = {} + + someTable.Function2 = function(Arg1, Arg2) + end + + someTable.Function2() -- Argument count mismatch + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") +{ + CheckResult result = check(R"( + type T = {method: ((T, number) -> number) & ((number) -> number)} + local T: T + + T.method(4) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "method_depends_on_table") +{ + CheckResult result = check(R"( + -- This catches a bug where x:m didn't count as a use of x + -- so toposort would happily reorder a definition of + -- function x:m before the definition of x. + function g() f() end + local x = {} + function x:m() end + function f() x:m() end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") +{ + CheckResult result = check(R"( + local T = {} + + function T:foo() + return T:bar(999), T:bar("hi") + end + + function T:bar(i) + return i + end + + local a, b = T:foo() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); + CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); +} + +TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") +{ + check(R"( + local T = {} + + function T.method(self) + self:method() + end + + function T.method2(self) + self:method() + end + + T:method2() + )"); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") +{ + CheckResult result = check(R"( + ("foo") + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + :lower() + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(50, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") +{ + // CLI-30902 + CheckResult result = check(R"( + --!strict + + type Foo = { + fooConn: () -> () | nil + } + + local Foo = {} + Foo.__index = Foo + + function Foo.new() + local self: Foo = { + fooConn = nil, + } + setmetatable(self, Foo) + + self.fooConn = function() + self:method() -- Key 'method' not found in table self + end + + return self + end + + function Foo:method() + print("foo") + end + + local foo = Foo.new() + + -- TODO This is the best our current refinement support can offer :( + local bar = foo.fooConn + if bar then bar() end + + -- foo.fooConn() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfSealed") +{ + CheckResult result = check(R"( +local x: {prop: number} = {prop=9999} +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") +{ + CheckResult result = check(R"( +--!nonstrict +local f = {} +function f:foo(a: number, b: number) end + +function bar(...) + f.foo(f, 1, ...) +end + +bar(2) +)"); + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") +{ + check(R"( + function Base64FileReader(data) + local reader = {} + local index: number + + function reader:PeekByte() + return data:byte(index) + end + + function reader:Byte() + return data:byte(index - 1) + end + + return reader + end + + Base64FileReader() + + function ReadMidiEvents(data) + + local reader = Base64FileReader(data) + + while reader:HasMore() do + (reader:Byte() % 128) + end + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_oop") +{ + CheckResult result = check(R"( + --!strict +local Class = {} +Class.__index = Class + +type Class = typeof(setmetatable({} :: { x: number }, Class)) + +function Class.new(x: number): Class + return setmetatable({x = x}, Class) +end + +function Class.getx(self: Class) + return self.x +end + +function test() + local c = Class.new(42) + local n = c:getx() + local nn = c.x + + print(string.format("%d %d", n, nn)) +end +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp new file mode 100644 index 000000000..baa259783 --- /dev/null +++ b/tests/TypeInfer.operators.test.cpp @@ -0,0 +1,759 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferOperators"); + +TEST_CASE_FIXTURE(Fixture, "or_joins_types") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:string|number = s + )"); + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("x")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") +{ + CheckResult result = check(R"( + local s = "a" or 10 + local x:number|string = s + local y = x or "s" + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number | string"); + CHECK_EQ(toString(*requireType("y")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" or "b" + local x:string = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(*requireType("s"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") +{ + CheckResult result = check(R"( + local s = "a" and 10 + local x:boolean|number = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "boolean | number"); +} + +TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") +{ + CheckResult result = check(R"( + local s = "a" and true + local x:boolean = s + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(*requireType("x"), *typeChecker.booleanType); +} + +TEST_CASE_FIXTURE(Fixture, "and_or_ternary") +{ + CheckResult result = check(R"( + local s = (1/2) > 0.5 and "a" or 10 + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number | string"); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: string) + return a + (tonumber(b) :: number), a .. b + end + local n, s = add(2,"3") + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + const FunctionTypeVar* functionType = get(requireType("add")); + + std::optional retType = first(functionType->retType); + CHECK_EQ(std::optional(typeChecker.numberType), retType); + CHECK_EQ(requireType("n"), typeChecker.numberType); + CHECK_EQ(requireType("s"), typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") +{ + CheckResult result = check(R"( + local PI=3.1415926535897931 + local SOLAR_MASS=4*PI * PI + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") +{ + CheckResult result = check(R"( + function add(a: number, b: any) + return a + b + end + local t = add(1,2) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") +{ + CheckResult result = check(R"( + local a = 4 + 8 + local b = a + 9 + local s = 'hotdogs' + local t = s .. s + local c = b - a + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + CHECK_EQ("string", toString(requireType("s"))); + CHECK_EQ("string", toString(requireType("t"))); + CHECK_EQ("number", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = a * b + local d = a * 2 + local e = a * 'cabbage' + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK_EQ("Vec3", toString(requireType("e"))); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( + --!strict + local Vec3 = {} + Vec3.__index = Vec3 + function Vec3.new() + return setmetatable({x=0, y=0, z=0}, Vec3) + end + + export type Vec3 = typeof(Vec3.new()) + + local thefun: any = function(self, o) return self end + + local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun + + Vec3.__mul = multiply + + local a = Vec3.new() + local b = Vec3.new() + local c = b * a + local d = 2 * a + local e = 'cabbage' * a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("Vec3", toString(requireType("a"))); + CHECK_EQ("Vec3", toString(requireType("b"))); + CHECK_EQ("Vec3", toString(requireType("c"))); + CHECK_EQ("Vec3", toString(requireType("d"))); + CHECK_EQ("Vec3", toString(requireType("e"))); +} + +TEST_CASE_FIXTURE(Fixture, "compare_numbers") +{ + CheckResult result = check(R"( + local a = 441 + local b = 0 + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "compare_strings") +{ + CheckResult result = check(R"( + local a = '441' + local b = '0' + local c = a < b + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable") +{ + CheckResult result = check(R"( + local a = {} + local b = {} + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + + REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") +{ + CheckResult result = check(R"( + local M = {} + function M.new() + return setmetatable({}, M) + end + type M = typeof(M.new()) + + local a = M.new() + local b = M.new() + local c = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* gen = get(result.errors[0]); + REQUIRE(gen != nullptr); + REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); +} + +TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + + local a = M.new() + local b = {} + local c = a < b -- line 10 + local d = b < a -- line 11 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + + REQUIRE_EQ((Location{{10, 18}, {10, 23}}), result.errors[0].location); + + REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); +} + +TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") +{ + CheckResult result = check(R"( + --!strict + local M = {} + function M.new() + return setmetatable({}, M) + end + function M.__lt(left, right) return true end + type M = typeof(M.new()) + + local a = M.new() + local b = {} + local c = a < b -- line 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = get(result.errors[0]); + REQUIRE(err != nullptr); + + // Frail. :| + REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", err->message); +} + +TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators") +{ + CheckResult result = check(R"( + --!nonstrict + + function maybe_a_number(): number? + return 50 + end + + local a = maybe_a_number() < maybe_a_number() + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_basic") +{ + CheckResult result = check(R"( + local s = 10 + s += 20 + )"); + CHECK_EQ(0, result.errors.size()); + CHECK_EQ(toString(*requireType("s")), "number"); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") +{ + CheckResult result = check(R"( + local s = 10 + s += true + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") +{ + CheckResult result = check(R"( + local s = 'hello' + s += 10 + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); + CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__add(a: V2, b: V2): V2 + return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 += v2 + )"); + CHECK_EQ(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") +{ + CheckResult result = check(R"( + --!strict + type V2B = { x: number, y: number } + local v2b: V2B = { x = 0, y = 0 } + local VMT = {} + type V2 = typeof(setmetatable(v2b, VMT)) + + function VMT.__mod(a: V2, b: V2): number + return a.x * b.x + a.y * b.y + end + + local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) + local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) + v1 %= v2 + )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + CHECK_EQ(*tm->wantedType, *requireType("v2")); + CHECK_EQ(*tm->givenType, *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +(f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions") +{ + CheckResult result = check(R"( +function f() return 1; end +function g() return 2; end +local x = false +(x and f or g)() +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") +{ + CheckResult result = check(R"( + --!strict + local foo = { + value = 10 + } + local mt = {} + setmetatable(foo, mt) + + mt.__unm = function(val: typeof(foo)): string + return val.value .. "test" + end + + local a = -foo + + local b = 1+-1 + + local bar = { + value = 10 + } + local c = -bar -- disallowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ("string", toString(requireType("a"))); + CHECK_EQ("number", toString(requireType("b"))); + + GenericError* gen = get(result.errors[0]); + REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); +} + +TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") +{ + CheckResult result = check(R"( + local b = not "string" + local c = not (math.random() > 0.5 and "string" or 7) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + REQUIRE_EQ("boolean", toString(requireType("b"))); + REQUIRE_EQ("boolean", toString(requireType("c"))); +} + +TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") +{ + CheckResult result = check(R"( + --!strict + local a = "1.24" + 123 -- not allowed + + local foo = { + value = 10 + } + + local b = foo + 1 -- not allowed + + local bar = { + value = 1 + } + + local mt = {} + + setmetatable(bar, mt) + + mt.__add = function(a: typeof(bar), b: number): number + return a.value + b + end + + local c = bar + 1 -- allowed + + local d = bar + foo -- not allowed + )"); + + LUAU_REQUIRE_ERROR_COUNT(3, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); + REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); + + TypeMismatch* tm2 = get(result.errors[2]); + CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); + CHECK_EQ(*tm2->givenType, *requireType("foo")); + + GenericError* gen2 = get(result.errors[1]); + REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); +} + +// CLI-29033 +TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") +{ + CheckResult result = check(R"( + function merge(lower, greater) + if lower.y == greater.y then + end + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") +{ + CheckResult result = check(R"( + local function f(x) + return x .. "y" + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + REQUIRE(get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") +{ + CheckResult result = check(R"( + local function f(x) + return "foo" .. x + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("(string) -> string", toString(requireType("f"))); +} + +TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") +{ + std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; + + std::string src = R"( + function foo(a, b) + )"; + + for (const auto& op : ops) + src += "local _ = a " + op + "b\n"; + + src += "end"; + + CheckResult result = check(src); + LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); + + CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") +{ + CheckResult result = check(R"( + --!strict + local t = {} + while true and t[1] do + print(t[1].test) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") +{ + CheckResult result = check(R"( + local a: boolean = true + local b: boolean = false + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") +{ + CheckResult result = check(R"( + local a: number | string = "" + local b: number | string = 1 + local foo = a < b + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + GenericError* ge = get(result.errors[0]); + REQUIRE(ge); + CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); +} + +TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") +{ + CheckResult result = check(R"( + --!strict + local _ + _ += _ and _ or _ and _ or _ and _ + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") +{ + // In non-strict mode, global definition is still allowed + { + CheckResult result = check(R"( + --!nonstrict + a = a + 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In strict mode we no longer generate two errors from lhs + { + CheckResult result = check(R"( + --!strict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } + + // In non-strict mode, compound assignment is not a definition, it's a modification + { + CheckResult result = check(R"( + --!nonstrict + a += 1 + print(a) + )"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); + CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2") +{ + CheckResult result = check(R"( +--!nonstrict +local a: number? = nil +local b: number = a or 1 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator") +{ + CheckResult result = check(R"( +--!strict +local a: number? = nil +local b: number = 1 or a + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ("number?", toString(tm->givenType)); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") +{ + CheckResult result = check(R"( + type Array = { [number]: T } + type Fiber = { id: number } + type null = {} + + local fiberStack: Array = {} + local index = 0 + + local function f(fiber: Fiber) + local a = fiber ~= fiberStack[index] + local b = fiberStack[index] ~= fiber + end + + return f + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") +{ + ScopedFastFlag sff1{"LuauEqConstraint", true}; + + CheckResult result = check(R"( + local function f(a: string | number, b: boolean | number) + return a == b + end + )"); + + // This doesn't produce any errors but for the wrong reasons. + // This unit test serves as a reminder to not try and unify the operands on `==`/`~=`. + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "refine_and_or") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x or 5 + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("number", toString(requireType("u"))); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.primitives.test.cpp b/tests/TypeInfer.primitives.test.cpp new file mode 100644 index 000000000..44b7b0d0b --- /dev/null +++ b/tests/TypeInfer.primitives.test.cpp @@ -0,0 +1,100 @@ +// This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details + +#include "Luau/AstQuery.h" +#include "Luau/BuiltinDefinitions.h" +#include "Luau/Scope.h" +#include "Luau/TypeInfer.h" +#include "Luau/TypeVar.h" +#include "Luau/VisitTypeVar.h" + +#include "Fixture.h" + +#include "doctest.h" + +using namespace Luau; + +TEST_SUITE_BEGIN("TypeInferPrimitives"); + +TEST_CASE_FIXTURE(Fixture, "cannot_call_primitives") +{ + CheckResult result = check("local foo = 5 foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + REQUIRE(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "string_length") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = #s + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("number", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_index") +{ + CheckResult result = check(R"( + local s = "Hello, World!" + local t = s[4] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + NotATable* nat = get(result.errors[0]); + REQUIRE(nat); + CHECK_EQ("string", toString(nat->ty)); + + CHECK_EQ("*unknown*", toString(requireType("t"))); +} + +TEST_CASE_FIXTURE(Fixture, "string_method") +{ + CheckResult result = check(R"( + local p = ("tacos"):len() + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(*requireType("p"), *typeChecker.numberType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_indirect") +{ + CheckResult result = check(R"( + local s:string + local l = s.lower + local p = l(s) + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(*requireType("p"), *typeChecker.stringType); +} + +TEST_CASE_FIXTURE(Fixture, "string_function_other") +{ + CheckResult result = check(R"( + local s:string + local p = s:match("foo") + )"); + CHECK_EQ(0, result.errors.size()); + + CHECK_EQ(toString(requireType("p")), "string?"); +} + +TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") +{ + ScopedFastFlag sff{"LuauErrorRecoveryType", true}; + + CheckResult result = check(R"( +local x: number = 9999 +function x:y(z: number) + local s: string = z +end +)"); + + LUAU_REQUIRE_ERROR_COUNT(2, result); +} + +TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index a5147d56a..9b347921f 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -1298,4 +1298,22 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "x_is_not_instance_or_else_not_part") CHECK_EQ("Part", toString(requireTypeAtPosition({5, 28}))); } +TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") +{ + const std::string code = R"( + function f(a) + if type(a) == "boolean" then + local a1 = a + elseif a.fn() then + local a2 = a + else + local a3 = a + end + end + )"; + CheckResult result = check(code); + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 3ed536ea6..7f8d8fec2 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -5,6 +5,8 @@ #include "doctest.h" #include "Luau/BuiltinDefinitions.h" +LUAU_FASTFLAG(BetterDiagnosticCodesInStudio) + using namespace Luau; TEST_SUITE_BEGIN("TypeSingletons"); @@ -353,7 +355,14 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); + if (FFlag::BetterDiagnosticCodesInStudio) + { + CHECK_EQ("Cannot have more than one table indexer", toString(result.errors[0])); + } + else + { + CHECK_EQ("Syntax error: Cannot have more than one table indexer", toString(result.errors[0])); + } } TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") @@ -445,7 +454,7 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si {"LuauSingletonTypes", true}, {"LuauEqConstraint", true}, {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, }; @@ -472,9 +481,9 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauEqConstraint", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, {"LuauWeakEqConstraint", false}, - {"LuauDoNotAccidentallyDependOnPointerOrdering", true} + {"LuauDoNotAccidentallyDependOnPointerOrdering", true}, }; CheckResult result = check(R"( @@ -497,7 +506,7 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -515,7 +524,7 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -544,7 +553,7 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") ScopedFastFlag sff[]{ {"LuauParseSingletonTypes", true}, {"LuauSingletonTypes", true}, - {"LuauWidenIfSupertypeIsFree", true}, + {"LuauWidenIfSupertypeIsFree2", true}, }; CheckResult result = check(R"( @@ -565,4 +574,97 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") CHECK_EQ("{string}", toString(requireType("t"))); } +TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") +{ + ScopedFastFlag sff[]{ + {"LuauParseSingletonTypes", true}, + {"LuauSingletonTypes", true}, + {"LuauWidenIfSupertypeIsFree2", true}, + }; + + CheckResult result = check(R"( + local function foo(my_enum: "A" | "B") end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"(("A" | "B") -> ())", toString(requireType("foo"))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = a:byte() + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 22}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); +} + +TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") +{ + ScopedFastFlag sff[]{ + {"LuauDiscriminableUnions2", true}, + {"LuauSingletonTypes", true}, + }; + + CheckResult result = check(R"( + local a: string = "hi" + if a == "hi" or a == "bye" then + local x = #a + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index a5eba5dfe..91140aaa4 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2384,4 +2384,504 @@ _ = (_.cos) LUAU_REQUIRE_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "cannot_call_tables") +{ + CheckResult result = check("local foo = {} foo()"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK(get(result.errors[0]) != nullptr); +} + +TEST_CASE_FIXTURE(Fixture, "table_length") +{ + CheckResult result = check(R"( + local t = {} + local s = #t + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK(nullptr != get(requireType("t"))); + CHECK_EQ(*typeChecker.numberType, *requireType("s")); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); + LUAU_REQUIRE_ERROR_COUNT(0, result); +} + +TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") +{ + CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.stringType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") +{ + CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ + typeChecker.numberType, + typeChecker.nilType, + }})); +} + +TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") +{ + check(R"( + local o + local v = o:i() + + function g(u) + v = u + end + + o:f(g) + o:h() + o:h() + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") +{ + CheckResult result = check(R"( + local Instance: any + local UDim2: any + + function Create(instanceType) + return function(data) + local obj = Instance.new(instanceType) + for k, v in pairs(data) do + if type(k) == 'number' then + --v.Parent = obj + else + obj[k] = v + end + end + return obj + end + end + + local topbarShadow = Create'ImageLabel'{ + Name = "TopBarShadow"; + Size = UDim2.new(1, 0, 0, 3); + Position = UDim2.new(0, 0, 1, 0); + Image = "rbxasset://textures/ui/TopBar/dropshadow.png"; + BackgroundTransparency = 1; + Active = false; + Visible = false; + }; + + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") +{ + CheckResult result = check(R"( + local T = {} + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + CHECK_EQ("T", toString(requireType("T"))); +} + +TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") +{ + CheckResult result = check(R"( + function foo(arr) + local work = {} + for i = 1, #arr do + work[i] = arr[i] + end + + return arr + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); + + const FunctionTypeVar* fooType = get(requireType("foo")); + REQUIRE(fooType); + + std::optional fooArg1 = first(fooType->argTypes); + REQUIRE(fooArg1); + + const TableTypeVar* fooArg1Table = get(*fooArg1); + REQUIRE(fooArg1Table); + + CHECK_EQ(fooArg1Table->state, TableState::Generic); +} + +/* + * This test case exposed an oversight in the treatment of free tables. + * Free tables, like free TypeVars, need to record the scope depth where they were created so that + * we do not erroneously let-generalize them when they are used in a nested lambda. + * + * For more information about let-generalization, see + * + * The important idea here is that the return type of Counter.new is a table with some metatable. + * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by + * the generalization process), then it loses the knowledge that its metatable will have an :incr() + * method. + */ +TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") +{ + CheckResult result = check(R"( + local Counter = {} + Counter.__index = Counter + + function Counter.new() + local self = setmetatable({count=0}, Counter) + return self + end + + function Counter:incr() + self.count = 1 + return self.count + end + + local self = Counter.new() + print(self:incr()) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + TableTypeVar* counterType = getMutable(requireType("Counter")); + REQUIRE(counterType); + + const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); + REQUIRE(newType); + + std::optional newRetType = *first(newType->retType); + REQUIRE(newRetType); + + const MetatableTypeVar* newRet = get(follow(*newRetType)); + REQUIRE(newRet); + + const TableTypeVar* newRetMeta = get(newRet->metatable); + REQUIRE(newRetMeta); + + CHECK(newRetMeta->props.count("incr")); + CHECK_EQ(follow(newRet->metatable), follow(requireType("Counter"))); +} + +// TODO: CLI-39624 +TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") +{ + CheckResult result = check(R"( + --!strict + local Option = {} + Option.__index = Option + function Option.Is(obj) + return (type(obj) == "table" and getmetatable(obj) == Option) + end + return Option + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") +{ + CheckResult result = check(R"( + --!strict + function f(U) + U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() + end + )"); + + ModulePtr module = getMainModule(); + CHECK_GE(100, module->internalTypes.typeVars.size()); +} + +TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") +{ + CheckResult result = check(R"( +local x = {} +x.a = "a" +x[0] = true +x.b = 37 +)"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") +{ + { + Fixture fix; + + // inherit env from parent fixture checker + fix.typeChecker.globalScope = typeChecker.globalScope; + + fix.check(R"( +--!nonstrict +type MT = typeof(setmetatable) +function wtf(arg: {MT}): typeof(table) + arg = wtf(arg) +end +)"); + } + + // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down + // note: it's important for typeck to be destroyed at this point! + { + for (auto& p : typeChecker.globalScope->bindings) + { + toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas + } + } +} + +TEST_CASE_FIXTURE(Fixture, "evil_table_unification") +{ + // this code re-infers the type of _ while processing fields of _, which can cause use-after-free + check(R"( +--!nonstrict +_ = ... +_:table(_,string)[_:gsub(_,...,n0)],_,_:gsub(_,string)[""],_:split(_,...,table)._,n0 = nil +do end +)"); +} + +TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") +{ + CheckResult result = check("local x = setmetatable({})"); + LUAU_REQUIRE_ERROR_COUNT(1, result); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") +{ + CheckResult result = check(R"( +--!nonstrict +local l0:any,l61:t0 = _,math +while _ do +_() +end +function _():t0 +end +type t0 = any +)"); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +type K = X +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("math"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") +{ + ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + + CheckResult result = check(R"( +type X = T +local a = {} +a.x = 4 +local b: X +a.y = 5 +local c: X +c = b +)"); + + LUAU_REQUIRE_NO_ERRORS(result); + + std::optional ty = requireType("a"); + REQUIRE(ty); + + const TableTypeVar* ttv = get(*ty); + REQUIRE(ttv); + CHECK(ttv->instantiatedTypeParams.empty()); +} + +TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") +{ + CheckResult result = check(R"( +local foo = {42} +local bar: number? +local baz = foo[bar] + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); +} + +TEST_CASE_FIXTURE(Fixture, "table_simple_call") +{ + CheckResult result = check(R"( +local a = setmetatable({ x = 2 }, { + __call = function(self) + return (self.x :: number) * 2 -- should work without annotation in the future + end +}) +local b = a() +local c = a(2) -- too many arguments + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") +{ + CheckResult result = check(R"( + type Foo = {x: string} + local t = {} + setmetatable(t, { + __index = function(x: string): ...Foo + return {x = x} + end + }) + + local foo = t.bar + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + ToStringOptions o; + o.exhaustive = true; + CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); +} + +TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") +{ + fileResolver.source["Module/Backend/Types"] = R"( + export type Fiber = { + return_: Fiber? + } + return {} + )"; + + fileResolver.source["Module/Backend"] = R"( + local Types = require(script.Types) + type Fiber = Types.Fiber + type ReactRenderer = { findFiberByHostInstance: () -> Fiber? } + + local function attach(renderer): () + local function getPrimaryFiber(fiber) + local alternate = fiber.alternate + return fiber + end + + local function getFiberIDForNative() + local fiber = renderer.findFiberByHostInstance() + fiber = fiber.return_ + return getPrimaryFiber(fiber) + end + end + + function culprit(renderer: ReactRenderer): () + attach(renderer) + end + + return culprit + )"; + + CheckResult result = frontend.check("Module/Backend"); +} + +TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t.x and t or 5 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); +} + +TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") +{ + CheckResult result = check(R"( + local t: {x: number?}? = {x = nil} + local u = t and t.x == 5 or t.x == 31337 + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); + CHECK_EQ("boolean", toString(requireType("u"))); +} + +/* + * We had an issue where part of the type of pairs() was an unsealed table. + * This test depends on FFlagDebugLuauFreezeArena to trigger it. + */ +TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") +{ + check(R"( + function _(l0:{n0:any}) + _ = pairs + end + )"); +} + +TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") +{ + CheckResult result = check(R"( +local t = {} + +function t.x(value) + for k,v in pairs(t) do end +end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + +/* + * When we add new properties to an unsealed table, we should do a level check and promote the property type to be at + * the level of the table. + */ +TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table") +{ + CheckResult result = check(R"( + --!strict + local T = {} + + local function f(prop) + T[1] = { + prop = prop, + } + end + + local function g() + local l = T[1].prop + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index d7bbad20d..571d0f8d6 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -60,15 +60,6 @@ TEST_CASE_FIXTURE(Fixture, "tc_error_2") }})); } -TEST_CASE_FIXTURE(Fixture, "tc_function") -{ - CheckResult result = check("function five() return 5 end"); - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* fiveType = get(requireType("five")); - REQUIRE(fiveType != nullptr); -} - TEST_CASE_FIXTURE(Fixture, "infer_locals_with_nil_value") { CheckResult result = check("local f = nil; f = 'hello world'"); @@ -108,4159 +99,565 @@ TEST_CASE_FIXTURE(Fixture, "infer_in_nocheck_mode") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "check_function_bodies") +TEST_CASE_FIXTURE(Fixture, "expr_statement") { - CheckResult result = check("function myFunction() local a = 0 a = true end"); + CheckResult result = check("local foo = 5 foo()"); LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 44}, Position{0, 48}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.booleanType, - }})); } -TEST_CASE_FIXTURE(Fixture, "infer_return_type") +TEST_CASE_FIXTURE(Fixture, "if_statement") { - CheckResult result = check("function take_five() return 5 end"); - LUAU_REQUIRE_NO_ERRORS(result); + CheckResult result = check(R"( + local a + local b - const FunctionTypeVar* takeFiveType = get(requireType("take_five")); - REQUIRE(takeFiveType != nullptr); + if true then + a = 'hello' + else + b = 999 + end + )"); - std::vector retVec = flatten(takeFiveType->retType).first; - REQUIRE(!retVec.empty()); + LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ(*follow(retVec[0]), *typeChecker.numberType); + CHECK_EQ(*typeChecker.stringType, *requireType("a")); + CHECK_EQ(*typeChecker.numberType, *requireType("b")); } -TEST_CASE_FIXTURE(Fixture, "infer_from_function_return_type") +TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") { - CheckResult result = check("function take_five() return 5 end local five = take_five()"); - LUAU_REQUIRE_NO_ERRORS(result); + CheckResult result = check(R"( + function foo() + return bar(999), bar("hi") + end + + function bar(i) + return i + end + )"); - CHECK_EQ(*typeChecker.numberType, *follow(requireType("five"))); + LUAU_REQUIRE_NO_ERRORS(result); + dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "cannot_call_primitives") +TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") { - CheckResult result = check("local foo = 5 foo()"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + CheckResult result = check(R"( + local o + o:method() - REQUIRE(get(result.errors[0]) != nullptr); -} + local p + p:method() -TEST_CASE_FIXTURE(Fixture, "cannot_call_tables") -{ - CheckResult result = check("local foo = {} foo()"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + o = p + )"); - CHECK(get(result.errors[0]) != nullptr); + LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "infer_that_function_does_not_return_a_table") +TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") { CheckResult result = check(R"( - function take_five() - return 5 - end - - take_five().prop = 888 + local M = require(script.parent.DoesNotMatter) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{5, 8}, Position{5, 24}}, NotATable{typeChecker.numberType}})); -} -TEST_CASE_FIXTURE(Fixture, "expr_statement") -{ - CheckResult result = check("local foo = 5 foo()"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + auto ed = get(result.errors[0]); + REQUIRE(ed); + + REQUIRE_EQ("parent", ed->symbol); } -TEST_CASE_FIXTURE(Fixture, "generic_function") +TEST_CASE_FIXTURE(Fixture, "weird_case") { CheckResult result = check(R"( - function id(x) return x end - local a = id(55) - local b = id(nil) + local function f() return 4 end + local d = math.deg(f()) )"); LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("a")); - CHECK_EQ(*typeChecker.nilType, *requireType("b")); + dumpErrors(result); } -TEST_CASE_FIXTURE(Fixture, "vararg_functions_should_allow_calls_of_any_types_and_size") +TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") { CheckResult result = check(R"( - function f(...) end + --!strict + local s + s(s, 'a') + )"); + LUAU_REQUIRE_ERROR_COUNT(0, result); +} - f(1) - f("foo", 2) +TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type") +{ + CheckResult result = check(R"( + --!strict + function u(t, w) + u(u, t) + end )"); LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "vararg_function_is_quantified") +#if 0 +// CLI-29798 +TEST_CASE_FIXTURE(Fixture, "crazy_complexity") { CheckResult result = check(R"( - local T = {} - function T.f(...) - local result = {} - - for i = 1, select("#", ...) do - local dictionary = select(i, ...) - for key, value in pairs(dictionary) do - result[key] = value - end - end - - return result - end - - return T + --!nonstrict + A:A():A():A():A():A():A():A():A():A():A():A() )"); - auto r = first(getMainModule()->getModuleScope()->returnType); - REQUIRE(r); - - TableTypeVar* ttv = getMutable(*r); - REQUIRE(ttv); - - TypeId k = ttv->props["f"].type; - REQUIRE(k); - - LUAU_REQUIRE_NO_ERRORS(result); + std::cout << "OK! Allocated " << typeChecker.typeVars.size() << " typevars" << std::endl; } +#endif -TEST_CASE_FIXTURE(Fixture, "for_loop") +TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") { CheckResult result = check(R"( - local q - for i=0, 50, 2 do - q = i - end + local err = (true).x + local c = err.Parent.Reward.GetChildren + local d = err.Parent.Reward + local e = err.Parent + local f = err )"); - LUAU_REQUIRE_NO_ERRORS(result); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + UnknownProperty* err = get(result.errors[0]); + REQUIRE(err != nullptr); + CHECK_EQ("boolean", toString(err->table)); + CHECK_EQ("x", err->key); - CHECK_EQ(*typeChecker.numberType, *requireType("q")); + CHECK_EQ("*unknown*", toString(requireType("c"))); + CHECK_EQ("*unknown*", toString(requireType("d"))); + CHECK_EQ("*unknown*", toString(requireType("e"))); + CHECK_EQ("*unknown*", toString(requireType("f"))); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop") +TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") { CheckResult result = check(R"( - local n - local s - for i, v in pairs({ "foo" }) do - n = i - s = v + local function f(x, y) + return x or y + end + + local function dont_crash(x, y) + local z: typeof(f(x, y)) = f(x, y) end )"); LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_next") +TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") { CheckResult result = check(R"( - local n - local s - for i, v in next, { "foo", "bar" } do - n = i - s = v - end + --!strict + -- An example of exponential blowup in number of types + -- The problem is that if we define function f(a) return x end + -- then this has type (t)->T where x:T + -- *but* it copies T each time f is applied + -- so { left = f("hi"), right = f(5) } + -- has type { left : T_L, right : T_R } + -- where T_L and T_R are copies of T. + -- x0 : T0 where T0 = {} + local x0 = {} + -- f0 : (t)->T0 + local function f0(a) return x0 end + -- x1 : T1 where T1 = { left : T0_L, right : T0_R } + local x1 = { left = f0("hi"), right = f0(5) } + -- f1 : (t)->T1 + local function f1(a) return x1 end + -- x2 : T2 where T2 = { left : T1_L, right : T1_R } + local x2 = { left = f1("hi"), right = f1(5) } + -- f2 : (t)->T2 + local function f2(a) return x2 end + -- etc etc + local x3 = { left = f2("hi"), right = f2(5) } + local function f3(a) return x3 end + local x4 = { left = f3("hi"), right = f3(5) } + return x4 )"); LUAU_REQUIRE_NO_ERRORS(result); + ModulePtr module = getMainModule(); - CHECK_EQ(*typeChecker.numberType, *requireType("n")); - CHECK_EQ(*typeChecker.stringType, *requireType("s")); + // If we're not careful about copying, this ends up with O(2^N) types rather than O(N) + // (in this case 5 vs 31). + CHECK_GE(5, module->interfaceTypes.typeVars.size()); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_an_iterator_of_type_any") +// In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type +// checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. +TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") { - CheckResult result = check(R"( - local it: any - local a, b - for i, v in it do - a, b = i, v - end - )"); +#if defined(LUAU_ENABLE_ASAN) + int limit = 250; +#elif defined(_DEBUG) || defined(_NOOPT) + int limit = 350; +#else + int limit = 600; +#endif + ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; + ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; + ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; - LUAU_REQUIRE_NO_ERRORS(result); + CHECK_NOTHROW(check("print('Hello!')")); + CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_should_fail_with_non_function_iterator") +TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") { - CheckResult result = check(R"( - local foo = "bar" - for i, v in foo do - end - )"); +#if defined(LUAU_ENABLE_ASAN) + int limit = 250; +#elif defined(_DEBUG) || defined(_NOOPT) + int limit = 350; +#else + int limit = 600; +#endif + + ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; + ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; + + CheckResult result = check(rep("do ", limit) + "local a = 1" + rep(" end", limit)); LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_just_one_iterator_is_ok") +TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") { - CheckResult result = check(R"( - local function keys(dictionary) - local new = {} - local index = 1 +#if defined(LUAU_ENABLE_ASAN) + int limit = 250; +#elif defined(_DEBUG) || defined(_NOOPT) + int limit = 350; +#else + int limit = 600; +#endif + ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; + ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; - for key in pairs(dictionary) do - new[index] = key - index = index + 1 - end + CheckResult result = check(R"(("foo"))" + rep(":lower()", limit)); - return new - end + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK(nullptr != get(result.errors[0])); +} + +TEST_CASE_FIXTURE(Fixture, "globals") +{ + CheckResult result = check(R"( + --!nonstrict + foo = true + foo = "now i'm a string!" )"); LUAU_REQUIRE_NO_ERRORS(result); + CHECK_EQ("any", toString(requireType("foo"))); } -TEST_CASE_FIXTURE(Fixture, "for_in_with_a_custom_iterator_should_type_check") +TEST_CASE_FIXTURE(Fixture, "globals2") { CheckResult result = check(R"( - local function range(l, h): () -> number - return function() - return l - end - end - - for n: string in range(1, 10) do - print(n) - end + --!nonstrict + foo = function() return 1 end + foo = "now i'm a string!" )"); LUAU_REQUIRE_ERROR_COUNT(1, result); + + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ("() -> (...any)", toString(tm->wantedType)); + CHECK_EQ("string", toString(tm->givenType)); + CHECK_EQ("() -> (...any)", toString(requireType("foo"))); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_error") +TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") { CheckResult result = check(R"( - function f(x) - gobble.prop = x.otherprop - end - - local p - for _, part in i_am_not_defined do - p = part - f(part) - part.thirdprop = false - end + --!strict + foo = true )"); - CHECK_EQ(2, result.errors.size()); + LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeId p = requireType("p"); - CHECK_EQ("*unknown*", toString(p)); + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ("foo", us->name); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_on_non_function") +TEST_CASE_FIXTURE(Fixture, "globals_everywhere") { CheckResult result = check(R"( - local bad_iter = 5 + --!nonstrict + foo = 1 - for a in bad_iter() do + if true then + bar = 2 end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); + LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE(get(result.errors[0])); + CHECK_EQ("any", toString(requireType("foo"))); + CHECK_EQ("any", toString(requireType("bar"))); } -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_factory_not_returning_the_right_amount_of_values") +TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") { CheckResult result = check(R"( - local function hasDivisors(value: number, table) - return false + do + local a = 1 end - function prime_iter(state, index) - while hasDivisors(index, state) do - index += 1 - end - - state[index] = true - return index - end - - function primes1() - return prime_iter, {} - end - - function primes2() - return prime_iter, {}, "" - end - - function primes3() - return prime_iter, {}, 2 - end - - for p in primes1() do print(p) end -- mismatch in argument count - - for p in primes2() do print(p) end -- mismatch in argument types, prime_iter takes {}, number, we are given {}, string - - for p in primes3() do print(p) end -- no error - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Arg); - CHECK_EQ(2, acm->expected); - CHECK_EQ(1, acm->actual); - - TypeMismatch* tm = get(result.errors[1]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_error_on_iterator_requiring_args_but_none_given") -{ - CheckResult result = check(R"( - function prime_iter(state, index) - return 1 - end - - for p in prime_iter do print(p) end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Arg); - CHECK_EQ(2, acm->expected); - CHECK_EQ(0, acm->actual); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any") -{ - CheckResult result = check(R"( - function bar(): any - return true - end - - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(typeChecker.anyType, requireType("a")); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_returns_any2") -{ - CheckResult result = check(R"( - function bar(): any - return true - end - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any") -{ - CheckResult result = check(R"( - local bar: any - - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_any2") -{ - CheckResult result = check(R"( - local bar: any - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error") -{ - CheckResult result = check(R"( - local a - for b in bar do - a = b - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_iterator_is_error2") -{ - CheckResult result = check(R"( - function bar(c) return c end - - local a - for b in bar() do - a = b - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_with_custom_iterator") -{ - CheckResult result = check(R"( - function primes() - return function (state: number) end, 2 - end - - for p, q in primes do - q = "" - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "if_statement") -{ - CheckResult result = check(R"( - local a - local b - - if true then - a = 'hello' - else - b = 999 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.stringType, *requireType("a")); - CHECK_EQ(*typeChecker.numberType, *requireType("b")); -} - -TEST_CASE_FIXTURE(Fixture, "while_loop") -{ - CheckResult result = check(R"( - local i - while true do - i = 8 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.numberType, *requireType("i")); -} - -TEST_CASE_FIXTURE(Fixture, "repeat_loop") -{ - CheckResult result = check(R"( - local i - repeat - i = 'hi' - until true - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(*typeChecker.stringType, *requireType("i")); -} - -TEST_CASE_FIXTURE(Fixture, "repeat_loop_condition_binds_to_its_block") -{ - CheckResult result = check(R"( - repeat - local x = true - until x - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "symbols_in_repeat_block_should_not_be_visible_beyond_until_condition") -{ - CheckResult result = check(R"( - repeat - local x = true - until x - - print(x) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "table_length") -{ - CheckResult result = check(R"( - local t = {} - local s = #t - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK(nullptr != get(requireType("t"))); - CHECK_EQ(*typeChecker.numberType, *requireType("s")); -} - -TEST_CASE_FIXTURE(Fixture, "string_length") -{ - CheckResult result = check(R"( - local s = "Hello, World!" - local t = #s - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "string_index") -{ - CheckResult result = check(R"( - local s = "Hello, World!" - local t = s[4] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - NotATable* nat = get(result.errors[0]); - REQUIRE(nat); - CHECK_EQ("string", toString(nat->ty)); - - CHECK_EQ("*unknown*", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "length_of_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local l = #this_is_not_defined - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local originalReward = unknown.Parent.Reward:GetChildren()[1] - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_indexer") -{ - CheckResult result = check("local a = {} a[0] = 7 a[0] = nil"); - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "wrong_assign_does_hit_indexer") -{ - CheckResult result = check("local a = {} a[0] = 7 a[0] = 't'"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "nil_assign_doesnt_hit_no_indexer") -{ - CheckResult result = check("local a = {a=1, b=2} a['a'] = nil"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{Position{0, 30}, Position{0, 33}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.nilType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "dot_on_error_type_does_not_produce_an_error") -{ - CheckResult result = check(R"( - local foo = (true).x - foo.x = foo.y - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_not_defined_with_colon") -{ - CheckResult result = check(R"( - local someTable = {} - - someTable.Function1 = function(Arg1) - end - - someTable.Function1() -- Argument count mismatch - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_it_wont_help_2") -{ - CheckResult result = check(R"( - local someTable = {} - - someTable.Function2 = function(Arg1, Arg2) - end - - someTable.Function2() -- Argument count mismatch - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "dont_suggest_using_colon_rather_than_dot_if_another_overload_works") -{ - CheckResult result = check(R"( - type T = {method: ((T, number) -> number) & ((number) -> number)} - local T: T - - T.method(4) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "list_only_alternative_overloads_that_match_argument_count") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply("") - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); - - ExtraInformation* ei = get(result.errors[1]); - REQUIRE(ei); - CHECK_EQ("Other overloads are also not viable: (number) -> string", ei->message); -} - -TEST_CASE_FIXTURE(Fixture, "list_all_overloads_if_no_overload_takes_given_argument_count") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply() - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("No overload for function accepts 0 arguments.", ge->message); - - ExtraInformation* ei = get(result.errors[1]); - REQUIRE(ei); - CHECK_EQ("Available overloads: (number) -> number; (number) -> string; and (number, number) -> number", ei->message); -} - -TEST_CASE_FIXTURE(Fixture, "dont_give_other_overloads_message_if_only_one_argument_matching_overload_exists") -{ - CheckResult result = check(R"( - local multiply: ((number)->number) & ((number)->string) & ((number, number)->number) - multiply(1, "") - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ(typeChecker.stringType, tm->givenType); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_type_from_selected_overload") -{ - CheckResult result = check(R"( - type T = {method: ((T, number) -> number) & ((number) -> string)} - local T: T - - local a = T.method(T, 4) - local b = T.method(5) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("number", toString(requireType("a"))); - CHECK_EQ("string", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "too_many_arguments") -{ - CheckResult result = check(R"( - --!nonstrict - - function g(a: number) end - - g() - - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto err = result.errors[0]; - auto acm = get(err); - REQUIRE(acm); - - CHECK_EQ(1, acm->expected); - CHECK_EQ(0, acm->actual); -} - -TEST_CASE_FIXTURE(Fixture, "any_type_propagates") -{ - CheckResult result = check(R"( - local foo: any - local bar = foo:method("argument") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "can_subscript_any") -{ - CheckResult result = check(R"( - local foo: any - local bar = foo[5] - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("bar"))); -} - -// Not strictly correct: metatables permit overriding this -TEST_CASE_FIXTURE(Fixture, "can_get_length_of_any") -{ - CheckResult result = check(R"( - local foo: any = {} - local bar = #foo - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_function") -{ - CheckResult result = check(R"( - function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "lambda_form_of_local_function_cannot_be_recursive") -{ - CheckResult result = check(R"( - local f = function() return f() end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_local_function") -{ - CheckResult result = check(R"( - local function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -// FIXME: This and the above case get handled very differently. It's pretty dumb. -// We really should unify the two code paths, probably by deleting AstStatFunction. -TEST_CASE_FIXTURE(Fixture, "another_recursive_local_function") -{ - CheckResult result = check(R"( - local count - function count(n: number) - if n == 0 then - return 0 - else - return count(n - 1) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_rets") -{ - CheckResult result = check(R"( - function f() - return f - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = () -> t1", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_function_type_in_args") -{ - CheckResult result = check(R"( - function f(g) - return f(f) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("t1 where t1 = (t1) -> ()", toString(requireType("f"))); -} - -// TODO: File a Jira about this -/* -TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") -{ - CheckResult result = check(R"( - function a(x) return 1 end - a(...) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); - - TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; - - auto iter = begin(varargPack); - auto endIter = end(varargPack); - - CHECK(iter != endIter); - ++iter; - CHECK(iter == endIter); - - CHECK(!iter.tail()); -} -*/ - -TEST_CASE_FIXTURE(Fixture, "method_depends_on_table") -{ - CheckResult result = check(R"( - -- This catches a bug where x:m didn't count as a use of x - -- so toposort would happily reorder a definition of - -- function x:m before the definition of x. - function g() f() end - local x = {} - function x:m() end - function f() x:m() end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_higher_order_function") -{ - CheckResult result = check(R"( - local Get_des - function Get_des(func) - Get_des(func) - end - - local function f(d) - d:IsA("BasePart") - d.Parent:FindFirstChild("Humanoid") - d:IsA("Decal") - end - Get_des(f) - - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_other_higher_order_function") -{ - CheckResult result = check(R"( - local d - d:foo() - d:foo() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "statements_are_topologically_sorted") -{ - CheckResult result = check(R"( - function foo() - return bar(999), bar("hi") - end - - function bar(i) - return i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "generic_table_method") -{ - CheckResult result = check(R"( - local T = {} - - function T:bar(i) - return i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId tType = requireType("T"); - TableTypeVar* tTable = getMutable(tType); - REQUIRE(tTable != nullptr); - - TypeId barType = tTable->props["bar"].type; - REQUIRE(barType != nullptr); - - const FunctionTypeVar* ftv = get(follow(barType)); - REQUIRE_MESSAGE(ftv != nullptr, "Should be a function: " << *barType); - - std::vector args = flatten(ftv->argTypes).first; - TypeId argType = args.at(1); - - CHECK_MESSAGE(get(argType), "Should be generic: " << *barType); -} - -TEST_CASE_FIXTURE(Fixture, "correctly_instantiate_polymorphic_member_functions") -{ - CheckResult result = check(R"( - local T = {} - - function T:foo() - return T:bar(5) - end - - function T:bar(i) - return i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - const TableTypeVar* t = get(requireType("T")); - REQUIRE(t != nullptr); - - std::optional fooProp = get(t->props, "foo"); - REQUIRE(bool(fooProp)); - - const FunctionTypeVar* foo = get(follow(fooProp->type)); - REQUIRE(bool(foo)); - - std::optional ret_ = first(foo->retType); - REQUIRE(bool(ret_)); - TypeId ret = follow(*ret_); - - REQUIRE_EQ(getPrimitiveType(ret), PrimitiveTypeVar::Number); -} - -TEST_CASE_FIXTURE(Fixture, "methods_are_topologically_sorted") -{ - CheckResult result = check(R"( - local T = {} - - function T:foo() - return T:bar(999), T:bar("hi") - end - - function T:bar(i) - return i - end - - local a, b = T:foo() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - CHECK_EQ(PrimitiveTypeVar::Number, getPrimitiveType(requireType("a"))); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "local_function") -{ - CheckResult result = check(R"( - function f() - return 8 - end - - function g() - local function f() - return 'hello' - end - return f - end - - local h = g() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId h = follow(requireType("h")); - - const FunctionTypeVar* ftv = get(h); - REQUIRE(ftv != nullptr); - - std::optional rt = first(ftv->retType); - REQUIRE(bool(rt)); - - TypeId retType = follow(*rt); - CHECK_EQ(PrimitiveTypeVar::String, getPrimitiveType(retType)); -} - -TEST_CASE_FIXTURE(Fixture, "unify_nearly_identical_recursive_types") -{ - CheckResult result = check(R"( - local o - o:method() - - local p - p:method() - - o = p - )"); -} - -/* - * We had a bug in instantiation where the argument types of 'f' and 'g' would be inferred as - * f {+ method: function(): (t2, T3...) +} - * g {+ method: function({+ method: function(): (t2, T3...) +}): (t5, T6...) +} - * - * The type of 'g' is totally wrong as t2 and t5 should be unified, as should T3 with T6. - * - * The correct unification of the argument to 'g' is - * - * {+ method: function(): (t5, T6...) +} - */ -TEST_CASE_FIXTURE(Fixture, "instantiate_cyclic_generic_function") -{ - auto result = check(R"( - function f(o) - o:method() - end - - function g(o) - f(o) - end - )"); - - TypeId g = requireType("g"); - const FunctionTypeVar* gFun = get(g); - REQUIRE(gFun != nullptr); - - auto optionArg = first(gFun->argTypes); - REQUIRE(bool(optionArg)); - - TypeId arg = follow(*optionArg); - const TableTypeVar* argTable = get(arg); - REQUIRE(argTable != nullptr); - - std::optional methodProp = get(argTable->props, "method"); - REQUIRE(bool(methodProp)); - - const FunctionTypeVar* methodFunction = get(methodProp->type); - REQUIRE(methodFunction != nullptr); - - std::optional methodArg = first(methodFunction->argTypes); - REQUIRE(bool(methodArg)); - - REQUIRE_EQ(follow(*methodArg), follow(arg)); -} - -TEST_CASE_FIXTURE(Fixture, "varlist_declared_by_for_in_loop_should_be_free") -{ - CheckResult result = check(R"( - local T = {} - - function T.f(p) - for i, v in pairs(p) do - T.f(v) - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "properly_infer_iteratee_is_a_free_table") -{ - // In this case, we cannot know the element type of the table {}. It could be anything. - // We therefore must initially ascribe a free typevar to iter. - CheckResult result = check(R"( - for iter in pairs({}) do - iter:g().p = true - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "quantify_methods_defined_using_dot_syntax_and_explicit_self_parameter") -{ - check(R"( - local T = {} - - function T.method(self) - self:method() - end - - function T.method2(self) - self:method() - end - - T:method2() - )"); -} - -TEST_CASE_FIXTURE(Fixture, "free_rhs_table_can_also_be_bound") -{ - check(R"( - local o - local v = o:i() - - function g(u) - v = u - end - - o:f(g) - o:h() - o:h() - )"); -} - -TEST_CASE_FIXTURE(Fixture, "require") -{ - fileResolver.source["game/A"] = R"( - local function hooty(x: number): string - return "Hi there!" - end - - return {hooty=hooty} - )"; - - fileResolver.source["game/B"] = R"( - local Hooty = require(game.A) - - local h -- free! - local i = Hooty.hooty(h) - )"; - - CheckResult aResult = frontend.check("game/A"); - dumpErrors(aResult); - LUAU_REQUIRE_NO_ERRORS(aResult); - - CheckResult bResult = frontend.check("game/B"); - dumpErrors(bResult); - LUAU_REQUIRE_NO_ERRORS(bResult); - - ModulePtr b = frontend.moduleResolver.modules["game/B"]; - - REQUIRE(b != nullptr); - - dumpErrors(bResult); - - std::optional iType = requireType(b, "i"); - REQUIRE_EQ("string", toString(*iType)); - - std::optional hType = requireType(b, "h"); - REQUIRE_EQ("number", toString(*hType)); -} - -TEST_CASE_FIXTURE(Fixture, "require_types") -{ - fileResolver.source["workspace/A"] = R"( - export type Point = {x: number, y: number} - - return {} - )"; - - fileResolver.source["workspace/B"] = R"( - local Hooty = require(workspace.A) - - local h: Hooty.Point - )"; - - CheckResult bResult = frontend.check("workspace/B"); - dumpErrors(bResult); - - ModulePtr b = frontend.moduleResolver.modules["workspace/B"]; - REQUIRE(b != nullptr); - - TypeId hType = requireType(b, "h"); - REQUIRE_MESSAGE(bool(get(hType)), "Expected table but got " << toString(hType)); -} - -TEST_CASE_FIXTURE(Fixture, "require_a_variadic_function") -{ - fileResolver.source["game/A"] = R"( - local T = {} - function T.f(...) end - return T - )"; - - fileResolver.source["game/B"] = R"( - local A = require(game.A) - local f = A.f - )"; - - CheckResult result = frontend.check("game/B"); - - ModulePtr bModule = frontend.moduleResolver.getModule("game/B"); - REQUIRE(bModule != nullptr); - - TypeId f = follow(requireType(bModule, "f")); - - const FunctionTypeVar* ftv = get(f); - REQUIRE(ftv); - - auto iter = begin(ftv->argTypes); - auto endIter = end(ftv->argTypes); - - REQUIRE(iter == endIter); - REQUIRE(iter.tail()); - - CHECK(get(*iter.tail())); -} - -TEST_CASE_FIXTURE(Fixture, "assign_prop_to_table_by_calling_any_yields_any") -{ - CheckResult result = check(R"( - local f: any - local T = {} - - T.prop = f() - - return T - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TableTypeVar* ttv = getMutable(requireType("T")); - REQUIRE(ttv); - REQUIRE(ttv->props.count("prop")); - - REQUIRE_EQ("any", toString(ttv->props["prop"].type)); -} - -TEST_CASE_FIXTURE(Fixture, "type_error_of_unknown_qualified_type") -{ - CheckResult result = check(R"( - local p: SomeModule.DoesNotExist - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - REQUIRE_EQ(result.errors[0], (TypeError{Location{{1, 17}, {1, 40}}, UnknownSymbol{"SomeModule.DoesNotExist"}})); -} - -TEST_CASE_FIXTURE(Fixture, "require_module_that_does_not_export") -{ - const std::string sourceA = R"( - )"; - - const std::string sourceB = R"( - local Hooty = require(script.Parent.A) - )"; - - fileResolver.source["game/Workspace/A"] = sourceA; - fileResolver.source["game/Workspace/B"] = sourceB; - - frontend.check("game/Workspace/A"); - frontend.check("game/Workspace/B"); - - ModulePtr aModule = frontend.moduleResolver.modules["game/Workspace/A"]; - ModulePtr bModule = frontend.moduleResolver.modules["game/Workspace/B"]; - - CHECK(aModule->errors.empty()); - REQUIRE_EQ(1, bModule->errors.size()); - CHECK_MESSAGE(get(bModule->errors[0]), "Should be IllegalRequire: " << toString(bModule->errors[0])); - - auto hootyType = requireType(bModule, "Hooty"); - - CHECK_EQ("*unknown*", toString(hootyType)); -} - -TEST_CASE_FIXTURE(Fixture, "warn_on_lowercase_parent_property") -{ - CheckResult result = check(R"( - local M = require(script.parent.DoesNotMatter) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto ed = get(result.errors[0]); - REQUIRE(ed); - - REQUIRE_EQ("parent", ed->symbol); -} - -TEST_CASE_FIXTURE(Fixture, "quantify_any_does_not_bind_to_itself") -{ - CheckResult result = check(R"( - local A : any - function A.B() end - A:C() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TypeId aType = requireType("A"); - CHECK_EQ(aType, typeChecker.anyType); -} - -TEST_CASE_FIXTURE(Fixture, "table_unifies_into_map") -{ - CheckResult result = check(R"( - local Instance: any - local UDim2: any - - function Create(instanceType) - return function(data) - local obj = Instance.new(instanceType) - for k, v in pairs(data) do - if type(k) == 'number' then - --v.Parent = obj - else - obj[k] = v - end - end - return obj - end - end - - local topbarShadow = Create'ImageLabel'{ - Name = "TopBarShadow"; - Size = UDim2.new(1, 0, 0, 3); - Position = UDim2.new(0, 0, 1, 0); - Image = "rbxasset://textures/ui/TopBar/dropshadow.png"; - BackgroundTransparency = 1; - Active = false; - Visible = false; - }; - - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "func_expr_doesnt_leak_free") -{ - CheckResult result = check(R"( - local p = function(x) return x end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - const Luau::FunctionTypeVar* fn = get(requireType("p")); - REQUIRE(fn); - auto ret = first(fn->retType); - REQUIRE(ret); - REQUIRE(get(follow(*ret))); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments") -{ - CheckResult result = check(R"( - function foo(a, b) - return a(b) - end - - function bar() - local c: ((number)->number, number)->number = foo -- no error - c = foo -- no error - local d: ((number)->number, string)->number = foo -- error from arg 2 (string) not being convertable to number from the call a(b) - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("((number) -> number, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((number) -> number, number) -> number", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "instantiate_generic_function_in_assignments2") -{ - CheckResult result = check(R"( - function foo(a, b) - return a(b) - end - - function bar() - local _: (string, string)->number = foo -- string cannot be converted to (string)->number - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("(string, string) -> number", toString(tm->wantedType)); - CHECK_EQ("((string) -> number, string) -> number", toString(*tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "string_method") -{ - CheckResult result = check(R"( - local p = ("tacos"):len() - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(*requireType("p"), *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "string_function_indirect") -{ - CheckResult result = check(R"( - local s:string - local l = s.lower - local p = l(s) - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(*requireType("p"), *typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "string_function_other") -{ - CheckResult result = check(R"( - local s:string - local p = s:match("foo") - )"); - CHECK_EQ(0, result.errors.size()); - - CHECK_EQ(toString(requireType("p")), "string?"); -} - -TEST_CASE_FIXTURE(Fixture, "weird_case") -{ - CheckResult result = check(R"( - local function f() return 4 end - local d = math.deg(f()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types") -{ - CheckResult result = check(R"( - local s = "a" or 10 - local x:string|number = s - )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("x")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_extras") -{ - CheckResult result = check(R"( - local s = "a" or 10 - local x:number|string = s - local y = x or "s" - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number | string"); - CHECK_EQ(toString(*requireType("y")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "or_joins_types_with_no_superfluous_union") -{ - CheckResult result = check(R"( - local s = "a" or "b" - local x:string = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(*requireType("s"), *typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean") -{ - CheckResult result = check(R"( - local s = "a" and 10 - local x:boolean|number = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "boolean | number"); -} - -TEST_CASE_FIXTURE(Fixture, "and_adds_boolean_no_superfluous_union") -{ - CheckResult result = check(R"( - local s = "a" and true - local x:boolean = s - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(*requireType("x"), *typeChecker.booleanType); -} - -TEST_CASE_FIXTURE(Fixture, "and_or_ternary") -{ - CheckResult result = check(R"( - local s = (1/2) > 0.5 and "a" or 10 - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number | string"); -} - -TEST_CASE_FIXTURE(Fixture, "first_argument_can_be_optional") -{ - CheckResult result = check(R"( - local T = {} - function T.new(a: number?, b: number?, c: number?) return 5 end - local m = T.new() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_ice_when_failing_the_occurs_check") -{ - CheckResult result = check(R"( - --!strict - local s - s(s, 'a') - )"); - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "occurs_check_does_not_recurse_forever_if_asked_to_traverse_a_cyclic_type") -{ - CheckResult result = check(R"( - --!strict - function u(t, w) - u(u, t) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -#if 0 -// CLI-29798 -TEST_CASE_FIXTURE(Fixture, "crazy_complexity") -{ - CheckResult result = check(R"( - --!nonstrict - A:A():A():A():A():A():A():A():A():A():A():A() - )"); - - std::cout << "OK! Allocated " << typeChecker.typeVars.size() << " typevars" << std::endl; -} -#endif - -// We had a bug where a cyclic union caused a stack overflow. -// ex type U = number | U -TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") -{ - CheckResult result = check(R"( - --!strict - - function f(a, b) - a:g(b or {}) - a:g(b) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "it_is_ok_not_to_supply_enough_retvals") -{ - CheckResult result = check(R"( - function get_two() return 5, 6 end - - local a = get_two() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions2") -{ - CheckResult result = check(R"( - function foo() end - - function bar() - local function foo() end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions_allowed_in_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - function foo() end - - function foo() end - - function bar() - local function foo() end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "duplicate_functions_with_different_signatures_not_allowed_in_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - function foo(): number - return 1 - end - foo() - - function foo(n: number): number - return 2 - end - foo() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> number", toString(tm->wantedType)); - CHECK_EQ("(number) -> number", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "tables_get_names_from_their_locals") -{ - CheckResult result = check(R"( - local T = {} - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("T", toString(requireType("T"))); -} - -TEST_CASE_FIXTURE(Fixture, "generalize_table_argument") -{ - CheckResult result = check(R"( - function foo(arr) - local work = {} - for i = 1, #arr do - work[i] = arr[i] - end - - return arr - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - const FunctionTypeVar* fooType = get(requireType("foo")); - REQUIRE(fooType); - - std::optional fooArg1 = first(fooType->argTypes); - REQUIRE(fooArg1); - - const TableTypeVar* fooArg1Table = get(*fooArg1); - REQUIRE(fooArg1Table); - - CHECK_EQ(fooArg1Table->state, TableState::Generic); -} - -TEST_CASE_FIXTURE(Fixture, "complicated_return_types_require_an_explicit_annotation") -{ - CheckResult result = check(R"( - local i = 0 - function most_of_the_natural_numbers(): number? - if i < 10 then - i = i + 1 - return i - else - return nil - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* functionType = get(requireType("most_of_the_natural_numbers")); - - std::optional retType = first(functionType->retType); - REQUIRE(retType); - CHECK(get(*retType)); -} - -TEST_CASE_FIXTURE(Fixture, "infer_higher_order_function") -{ - CheckResult result = check(R"( - function apply(f, x) - return f(x) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("apply")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(2, argVec.size()); - - const FunctionTypeVar* fType = get(follow(argVec[0])); - REQUIRE(fType != nullptr); - - std::vector fArgs = flatten(fType->argTypes).first; - - TypeId xType = argVec[1]; - - CHECK_EQ(1, fArgs.size()); - CHECK_EQ(xType, fArgs[0]); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_2") -{ - CheckResult result = check(R"( - function bottomupmerge(comp, a, b, left, mid, right) - local i, j = left, mid - for k = left, right do - if i < mid and (j > right or not comp(a[j], a[i])) then - b[k] = a[i] - i = i + 1 - else - b[k] = a[j] - j = j + 1 - end - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("bottomupmerge")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(6, argVec.size()); - - const FunctionTypeVar* fType = get(follow(argVec[0])); - REQUIRE(fType != nullptr); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_3") -{ - CheckResult result = check(R"( - function swap(p) - local t = p[0] - p[0] = p[1] - p[1] = t - return nil - end - - function swapTwice(p) - swap(p) - swap(p) - return p - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* ftv = get(requireType("swapTwice")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(1, argVec.size()); - - const TableTypeVar* argType = get(follow(argVec[0])); - REQUIRE(argType != nullptr); - - CHECK(bool(argType->indexer)); -} - -TEST_CASE_FIXTURE(Fixture, "higher_order_function_4") -{ - CheckResult result = check(R"( - function bottomupmerge(comp, a, b, left, mid, right) - local i, j = left, mid - for k = left, right do - if i < mid and (j > right or not comp(a[j], a[i])) then - b[k] = a[i] - i = i + 1 - else - b[k] = a[j] - j = j + 1 - end - end - end - - function mergesort(arr, comp) - local work = {} - for i = 1, #arr do - work[i] = arr[i] - end - local width = 1 - while width < #arr do - for i = 1, #arr, 2*width do - bottomupmerge(comp, arr, work, i, math.min(i+width, #arr), math.min(i+2*width-1, #arr)) - end - local temp = work - work = arr - arr = temp - width = width * 2 - end - return arr - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); - - /* - * mergesort takes two arguments: an array of some type T and a function that takes two Ts. - * We must assert that these two types are in fact the same type. - * In other words, comp(arr[x], arr[y]) is well-typed. - */ - - const FunctionTypeVar* ftv = get(requireType("mergesort")); - REQUIRE(ftv != nullptr); - - std::vector argVec = flatten(ftv->argTypes).first; - - REQUIRE_EQ(2, argVec.size()); - - const TableTypeVar* arg0 = get(follow(argVec[0])); - REQUIRE(arg0 != nullptr); - REQUIRE(bool(arg0->indexer)); - - const FunctionTypeVar* arg1 = get(follow(argVec[1])); - REQUIRE(arg1 != nullptr); - REQUIRE_EQ(2, size(arg1->argTypes)); - - std::vector arg1Args = flatten(arg1->argTypes).first; - - CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[0]); - CHECK_EQ(*arg0->indexer->indexResultType, *arg1Args[1]); -} - -TEST_CASE_FIXTURE(Fixture, "type_errors_infer_types") -{ - CheckResult result = check(R"( - local err = (true).x - local c = err.Parent.Reward.GetChildren - local d = err.Parent.Reward - local e = err.Parent - local f = err - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownProperty* err = get(result.errors[0]); - REQUIRE(err != nullptr); - CHECK_EQ("boolean", toString(err->table)); - CHECK_EQ("x", err->key); - - CHECK_EQ("*unknown*", toString(requireType("c"))); - CHECK_EQ("*unknown*", toString(requireType("d"))); - CHECK_EQ("*unknown*", toString(requireType("e"))); - CHECK_EQ("*unknown*", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "calling_error_type_yields_error") -{ - CheckResult result = check(R"( - local a = unknown.Parent.Reward.GetChildren() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* err = get(result.errors[0]); - REQUIRE(err != nullptr); - - CHECK_EQ("unknown", err->name); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "chain_calling_error_type_yields_error") -{ - CheckResult result = check(R"( - local a = Utility.Create "Foo" {} - )"); - - CHECK_EQ("*unknown*", toString(requireType("a"))); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable") -{ - CheckResult result = check(R"( - function add(a: number, b: string) - return a + (tonumber(b) :: number), a .. b - end - local n, s = add(2,"3") - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - const FunctionTypeVar* functionType = get(requireType("add")); - - std::optional retType = first(functionType->retType); - CHECK_EQ(std::optional(typeChecker.numberType), retType); - CHECK_EQ(requireType("n"), typeChecker.numberType); - CHECK_EQ(requireType("s"), typeChecker.stringType); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_no_metatable_with_follows") -{ - CheckResult result = check(R"( - local PI=3.1415926535897931 - local SOLAR_MASS=4*PI * PI - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(requireType("SOLAR_MASS"), typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "primitive_arith_possible_metatable") -{ - CheckResult result = check(R"( - function add(a: number, b: any) - return a + b - end - local t = add(1,2) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("t"))); -} - -TEST_CASE_FIXTURE(Fixture, "some_primitive_binary_ops") -{ - CheckResult result = check(R"( - local a = 4 + 8 - local b = a + 9 - local s = 'hotdogs' - local t = s .. s - local c = b - a - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number", toString(requireType("a"))); - CHECK_EQ("number", toString(requireType("b"))); - CHECK_EQ("string", toString(requireType("s"))); - CHECK_EQ("string", toString(requireType("t"))); - CHECK_EQ("number", toString(requireType("c"))); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( - --!strict - local Vec3 = {} - Vec3.__index = Vec3 - function Vec3.new() - return setmetatable({x=0, y=0, z=0}, Vec3) - end - - export type Vec3 = typeof(Vec3.new()) - - local thefun: any = function(self, o) return self end - - local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun - - Vec3.__mul = multiply - - local a = Vec3.new() - local b = Vec3.new() - local c = a * b - local d = a * 2 - local e = a * 'cabbage' - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK_EQ("Vec3", toString(requireType("e"))); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_overloaded_multiply_that_is_an_intersection_on_rhs") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( - --!strict - local Vec3 = {} - Vec3.__index = Vec3 - function Vec3.new() - return setmetatable({x=0, y=0, z=0}, Vec3) - end - - export type Vec3 = typeof(Vec3.new()) - - local thefun: any = function(self, o) return self end - - local multiply: ((Vec3, Vec3) -> Vec3) & ((Vec3, number) -> Vec3) = thefun - - Vec3.__mul = multiply - - local a = Vec3.new() - local b = Vec3.new() - local c = b * a - local d = 2 * a - local e = 'cabbage' * a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("Vec3", toString(requireType("a"))); - CHECK_EQ("Vec3", toString(requireType("b"))); - CHECK_EQ("Vec3", toString(requireType("c"))); - CHECK_EQ("Vec3", toString(requireType("d"))); - CHECK_EQ("Vec3", toString(requireType("e"))); -} - -TEST_CASE_FIXTURE(Fixture, "compare_numbers") -{ - CheckResult result = check(R"( - local a = 441 - local b = 0 - local c = a < b - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "compare_strings") -{ - CheckResult result = check(R"( - local a = '441' - local b = '0' - local c = a < b - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_have_a_metatable") -{ - CheckResult result = check(R"( - local a = {} - local b = {} - local c = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* gen = get(result.errors[0]); - - REQUIRE_EQ(gen->message, "Type a cannot be compared with < because it has no metatable"); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_indirectly_compare_types_that_do_not_offer_overloaded_ordering_operators") -{ - CheckResult result = check(R"( - local M = {} - function M.new() - return setmetatable({}, M) - end - type M = typeof(M.new()) - - local a = M.new() - local b = M.new() - local c = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* gen = get(result.errors[0]); - REQUIRE(gen != nullptr); - REQUIRE_EQ(gen->message, "Table M does not offer metamethod __lt"); -} - -TEST_CASE_FIXTURE(Fixture, "cannot_compare_tables_that_do_not_have_the_same_metatable") -{ - CheckResult result = check(R"( - --!strict - local M = {} - function M.new() - return setmetatable({}, M) - end - function M.__lt(left, right) return true end - - local a = M.new() - local b = {} - local c = a < b -- line 10 - local d = b < a -- line 11 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - REQUIRE_EQ((Location{{10, 18}, {10, 23}}), result.errors[0].location); - - REQUIRE_EQ((Location{{11, 18}, {11, 23}}), result.errors[1].location); -} - -TEST_CASE_FIXTURE(Fixture, "produce_the_correct_error_message_when_comparing_a_table_with_a_metatable_with_one_that_does_not") -{ - CheckResult result = check(R"( - --!strict - local M = {} - function M.new() - return setmetatable({}, M) - end - function M.__lt(left, right) return true end - type M = typeof(M.new()) - - local a = M.new() - local b = {} - local c = a < b -- line 10 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - auto err = get(result.errors[0]); - REQUIRE(err != nullptr); - - // Frail. :| - REQUIRE_EQ("Types M and b cannot be compared with < because they do not have the same metatable", err->message); -} - -TEST_CASE_FIXTURE(Fixture, "in_nonstrict_mode_strip_nil_from_intersections_when_considering_relational_operators") -{ - CheckResult result = check(R"( - --!nonstrict - - function maybe_a_number(): number? - return 50 - end - - local a = maybe_a_number() < maybe_a_number() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -/* - * This test case exposed an oversight in the treatment of free tables. - * Free tables, like free TypeVars, need to record the scope depth where they were created so that - * we do not erroneously let-generalize them when they are used in a nested lambda. - * - * For more information about let-generalization, see - * - * The important idea here is that the return type of Counter.new is a table with some metatable. - * That metatable *must* be the same TypeVar as the type of Counter. If it is a copy (produced by - * the generalization process), then it loses the knowledge that its metatable will have an :incr() - * method. - */ -TEST_CASE_FIXTURE(Fixture, "dont_quantify_table_that_belongs_to_outer_scope") -{ - CheckResult result = check(R"( - local Counter = {} - Counter.__index = Counter - - function Counter.new() - local self = setmetatable({count=0}, Counter) - return self - end - - function Counter:incr() - self.count = 1 - return self.count - end - - local self = Counter.new() - print(self:incr()) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - TableTypeVar* counterType = getMutable(requireType("Counter")); - REQUIRE(counterType); - - const FunctionTypeVar* newType = get(follow(counterType->props["new"].type)); - REQUIRE(newType); - - std::optional newRetType = *first(newType->retType); - REQUIRE(newRetType); - - const MetatableTypeVar* newRet = get(follow(*newRetType)); - REQUIRE(newRet); - - const TableTypeVar* newRetMeta = get(newRet->metatable); - REQUIRE(newRetMeta); - - CHECK(newRetMeta->props.count("incr")); - CHECK_EQ(follow(newRet->metatable), follow(requireType("Counter"))); -} - -// TODO: CLI-39624 -TEST_CASE_FIXTURE(Fixture, "instantiate_tables_at_scope_level") -{ - CheckResult result = check(R"( - --!strict - local Option = {} - Option.__index = Option - function Option.Is(obj) - return (type(obj) == "table" and getmetatable(obj) == Option) - end - return Option - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "typeguard_doesnt_leak_to_elseif") -{ - const std::string code = R"( - function f(a) - if type(a) == "boolean" then - local a1 = a - elseif a.fn() then - local a2 = a - else - local a3 = a - end - end - )"; - CheckResult result = check(code); - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "should_be_able_to_infer_this_without_stack_overflowing") -{ - CheckResult result = check(R"( - local function f(x, y) - return x or y - end - - local function dont_crash(x, y) - local z: typeof(f(x, y)) = f(x, y) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "inferring_hundreds_of_self_calls_should_not_suffocate_memory") -{ - CheckResult result = check(R"( - ("foo") - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - :lower() - )"); - - ModulePtr module = getMainModule(); - CHECK_GE(50, module->internalTypes.typeVars.size()); -} - -TEST_CASE_FIXTURE(Fixture, "inferring_crazy_table_should_also_be_quick") -{ - CheckResult result = check(R"( - --!strict - function f(U) - U(w:s(an):c()():c():U(s):c():c():U(s):c():U(s):cU()):c():U(s):c():U(s):c():c():U(s):c():U(s):cU() - end - )"); - - ModulePtr module = getMainModule(); - CHECK_GE(100, module->internalTypes.typeVars.size()); -} - -TEST_CASE_FIXTURE(Fixture, "exponential_blowup_from_copying_types") -{ - CheckResult result = check(R"( - --!strict - -- An example of exponential blowup in number of types - -- The problem is that if we define function f(a) return x end - -- then this has type (t)->T where x:T - -- *but* it copies T each time f is applied - -- so { left = f("hi"), right = f(5) } - -- has type { left : T_L, right : T_R } - -- where T_L and T_R are copies of T. - -- x0 : T0 where T0 = {} - local x0 = {} - -- f0 : (t)->T0 - local function f0(a) return x0 end - -- x1 : T1 where T1 = { left : T0_L, right : T0_R } - local x1 = { left = f0("hi"), right = f0(5) } - -- f1 : (t)->T1 - local function f1(a) return x1 end - -- x2 : T2 where T2 = { left : T1_L, right : T1_R } - local x2 = { left = f1("hi"), right = f1(5) } - -- f2 : (t)->T2 - local function f2(a) return x2 end - -- etc etc - local x3 = { left = f2("hi"), right = f2(5) } - local function f3(a) return x3 end - local x4 = { left = f3("hi"), right = f3(5) } - return x4 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - ModulePtr module = getMainModule(); - - // If we're not careful about copying, this ends up with O(2^N) types rather than O(N) - // (in this case 5 vs 31). - CHECK_GE(5, module->interfaceTypes.typeVars.size()); -} - -TEST_CASE_FIXTURE(Fixture, "mutual_recursion") -{ - CheckResult result = check(R"( - --!strict - - function newPlayerCharacter() - startGui() -- Unknown symbol 'startGui' - end - - local characterAddedConnection: any - function startGui() - characterAddedConnection = game:GetService("Players").LocalPlayer.CharacterAdded:connect(newPlayerCharacter) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "toposort_doesnt_break_mutual_recursion") -{ - CheckResult result = check(R"( - --!strict - local x = nil - function f() g() end - -- make sure print(x) doesn't get toposorted here, breaking the mutual block - function g() x = f end - print(x) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - dumpErrors(result); -} - -TEST_CASE_FIXTURE(Fixture, "object_constructor_can_refer_to_method_of_self") -{ - // CLI-30902 - CheckResult result = check(R"( - --!strict - - type Foo = { - fooConn: () -> () | nil - } - - local Foo = {} - Foo.__index = Foo - - function Foo.new() - local self: Foo = { - fooConn = nil, - } - setmetatable(self, Foo) - - self.fooConn = function() - self:method() -- Key 'method' not found in table self - end - - return self - end - - function Foo:method() - print("foo") - end - - local foo = Foo.new() - - -- TODO This is the best our current refinement support can offer :( - local bar = foo.fooConn - if bar then bar() end - - -- foo.fooConn() - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "replace_every_free_type_when_unifying_a_complex_function_with_any") -{ - CheckResult result = check(R"( - local a: any - local b - for _, i in pairs(a) do - b = i - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("b"))); -} - -// In these tests, a successful parse is required, so we need the parser to return the AST and then we can test the recursion depth limit in type -// checker. We also want it to somewhat match up with production values, so we push up the parser recursion limit a little bit instead. -TEST_CASE_FIXTURE(Fixture, "check_type_infer_recursion_limit") -{ -#if defined(LUAU_ENABLE_ASAN) - int limit = 250; -#elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; -#else - int limit = 600; -#endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauTypeInferRecursionLimit{"LuauTypeInferRecursionLimit", limit - 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", 0}; - - CHECK_NOTHROW(check("print('Hello!')")); - CHECK_THROWS_AS(check("function f() return " + rep("{a=", limit) + "'a'" + rep("}", limit) + " end"), std::runtime_error); -} - -TEST_CASE_FIXTURE(Fixture, "check_block_recursion_limit") -{ -#if defined(LUAU_ENABLE_ASAN) - int limit = 250; -#elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; -#else - int limit = 600; -#endif - - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; - - CheckResult result = check(rep("do ", limit) + "local a = 1" + rep(" end", limit)); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(nullptr != get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") -{ -#if defined(LUAU_ENABLE_ASAN) - int limit = 250; -#elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; -#else - int limit = 600; -#endif - ScopedFastInt luauRecursionLimit{"LuauRecursionLimit", limit + 100}; - ScopedFastInt luauCheckRecursionLimit{"LuauCheckRecursionLimit", limit - 100}; - - CheckResult result = check(R"(("foo"))" + rep(":lower()", limit)); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK(nullptr != get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_basic") -{ - CheckResult result = check(R"( - local s = 10 - s += 20 - )"); - CHECK_EQ(0, result.errors.size()); - CHECK_EQ(toString(*requireType("s")), "number"); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_op") -{ - CheckResult result = check(R"( - local s = 10 - s += true - )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 13}, {2, 17}}, TypeMismatch{typeChecker.numberType, typeChecker.booleanType}})); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_result") -{ - CheckResult result = check(R"( - local s = 'hello' - s += 10 - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(result.errors[0], (TypeError{Location{{2, 8}, {2, 9}}, TypeMismatch{typeChecker.numberType, typeChecker.stringType}})); - CHECK_EQ(result.errors[1], (TypeError{Location{{2, 8}, {2, 15}}, TypeMismatch{typeChecker.stringType, typeChecker.numberType}})); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_metatable") -{ - CheckResult result = check(R"( - --!strict - type V2B = { x: number, y: number } - local v2b: V2B = { x = 0, y = 0 } - local VMT = {} - type V2 = typeof(setmetatable(v2b, VMT)) - - function VMT.__add(a: V2, b: V2): V2 - return setmetatable({ x = a.x + b.x, y = a.y + b.y }, VMT) - end - - local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) - local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) - v1 += v2 - )"); - CHECK_EQ(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "compound_assign_mismatch_metatable") -{ - CheckResult result = check(R"( - --!strict - type V2B = { x: number, y: number } - local v2b: V2B = { x = 0, y = 0 } - local VMT = {} - type V2 = typeof(setmetatable(v2b, VMT)) - - function VMT.__mod(a: V2, b: V2): number - return a.x * b.x + a.y * b.y - end - - local v1: V2 = setmetatable({ x = 1, y = 2 }, VMT) - local v2: V2 = setmetatable({ x = 3, y = 4 }, VMT) - v1 %= v2 - )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - CHECK_EQ(*tm->wantedType, *requireType("v2")); - CHECK_EQ(*tm->givenType, *typeChecker.numberType); -} - -TEST_CASE_FIXTURE(Fixture, "dont_ice_if_a_TypePack_is_an_error") -{ - CheckResult result = check(R"( - --!strict - function f(s) - print(s) - return f - end - - f("foo")("bar") - )"); -} - -TEST_CASE_FIXTURE(Fixture, "check_function_before_lambda_that_uses_it") -{ - CheckResult result = check(R"( - --!nonstrict - - function f() - return 114 - end - - return function() - return f():andThen() - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "it_is_ok_to_oversaturate_a_higher_order_function_argument") -{ - CheckResult result = check(R"( - function onerror() end - function foo() end - xpcall(foo, onerror) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "another_indirect_function_case_where_it_is_ok_to_provide_too_many_arguments") -{ - CheckResult result = check(R"( - local mycb: (number, number) -> () - - function f() end - - mycb = f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "call_to_any_yields_any") -{ - CheckResult result = check(R"( - local a: any - local b = a() - )"); - - REQUIRE_EQ("any", toString(requireType("b"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals") -{ - CheckResult result = check(R"( - --!nonstrict - foo = true - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ("any", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals2") -{ - CheckResult result = check(R"( - --!nonstrict - foo = function() return 1 end - foo = "now i'm a string!" - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ("() -> (...any)", toString(tm->wantedType)); - CHECK_EQ("string", toString(tm->givenType)); - CHECK_EQ("() -> (...any)", toString(requireType("foo"))); -} - -TEST_CASE_FIXTURE(Fixture, "globals_are_banned_in_strict_mode") -{ - CheckResult result = check(R"( - --!strict - foo = true - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* us = get(result.errors[0]); - REQUIRE(us); - CHECK_EQ("foo", us->name); -} - -TEST_CASE_FIXTURE(Fixture, "globals_everywhere") -{ - CheckResult result = check(R"( - --!nonstrict - foo = 1 - - if true then - bar = 2 - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("any", toString(requireType("foo"))); - CHECK_EQ("any", toString(requireType("bar"))); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfAny") -{ - CheckResult result = check(R"( -local x: any = {} -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfSealed") -{ - CheckResult result = check(R"( -local x: {prop: number} = {prop=9999} -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfNumber") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( -local x: number = 9999 -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "CheckMethodsOfError") -{ - CheckResult result = check(R"( -local x = (true).foo -function x:y(z: number) - local s: string = z -end -)"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "CallOrOfFunctions") -{ - CheckResult result = check(R"( -function f() return 1; end -function g() return 2; end -(f or g)() -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "CallAndOrOfFunctions") -{ - CheckResult result = check(R"( -function f() return 1; end -function g() return 2; end -local x = false -(x and f or g)() -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "MixedPropertiesAndIndexers") -{ - CheckResult result = check(R"( -local x = {} -x.a = "a" -x[0] = true -x.b = 37 -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_do") -{ - CheckResult result = check(R"( - do - local a = 1 - end - - print(a) -- oops! - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* us = get(result.errors[0]); - REQUIRE(us); - CHECK_EQ(us->name, "a"); -} - -TEST_CASE_FIXTURE(Fixture, "correctly_scope_locals_while") -{ - CheckResult result = check(R"( - while true do - local a = 1 - end - - print(a) -- oops! - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownSymbol* us = get(result.errors[0]); - REQUIRE(us); - CHECK_EQ(us->name, "a"); -} - -TEST_CASE_FIXTURE(Fixture, "ipairs_produces_integral_indices") -{ - CheckResult result = check(R"( - local key - for i, e in ipairs({}) do key = i end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - REQUIRE_EQ("number", toString(requireType("key"))); -} - -TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") -{ - CHECK_NOTHROW(check(R"( - --!nonstrict - f,g = ... - f(g(...))[...] = nil - f,xpcall = ... - local value = g(...)(g(...)) - )")); - - CHECK_EQ("any", toString(requireType("value"))); -} - -TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_nonstrict") -{ - CheckResult result = check(R"( - --!nonstrict - - local function f1(v): number? - if v then - return 1 - end - end - - local function f2(v) - if v then - return 1 - end - end - - local function f3(v): () - if v then - return - end - end - - local function f4(v) - if v then - return - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - FunctionExitsWithoutReturning* err = get(result.errors[0]); - CHECK(err); -} - -TEST_CASE_FIXTURE(Fixture, "report_exiting_without_return_strict") -{ - CheckResult result = check(R"( - --!strict - - local function f1(v): number? - if v then - return 1 - end - end - - local function f2(v) - if v then - return 1 - end - end - - local function f3(v): () - if v then - return - end - end - - local function f4(v) - if v then - return - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - FunctionExitsWithoutReturning* annotatedErr = get(result.errors[0]); - CHECK(annotatedErr); - - FunctionExitsWithoutReturning* inferredErr = get(result.errors[1]); - CHECK(inferredErr); -} - -// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") -// { -// CheckResult result = check(R"( -// function f(a) -// if a.cond then -// return a.method() -// end -// end -// )"); - -// LUAU_REQUIRE_NO_ERRORS(result); - -// CHECK_EQ("A", toString(requireType("f"))); -// } - -TEST_CASE_FIXTURE(Fixture, "warn_if_you_try_to_require_a_non_modulescript") -{ - fileResolver.source["Modules/A"] = ""; - fileResolver.sourceTypes["Modules/A"] = SourceCode::Local; - - fileResolver.source["Modules/B"] = R"( - local M = require(script.Parent.A) - )"; - - CheckResult result = frontend.check("Modules/B"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "calling_function_with_incorrect_argument_type_yields_errors_spanning_argument") -{ - CheckResult result = check(R"( - function foo(a: number, b: string) end - - foo("Test", 123) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - CHECK_EQ(result.errors[0], (TypeError{Location{Position{3, 12}, Position{3, 18}}, TypeMismatch{ - typeChecker.numberType, - typeChecker.stringType, - }})); - - CHECK_EQ(result.errors[1], (TypeError{Location{Position{3, 20}, Position{3, 23}}, TypeMismatch{ - typeChecker.stringType, - typeChecker.numberType, - }})); -} - -TEST_CASE_FIXTURE(Fixture, "calling_function_with_anytypepack_doesnt_leak_free_types") -{ - CheckResult result = check(R"( - --!nonstrict - - function Test(a) - return 1, "" - end - - - local tab = {} - table.insert(tab, Test(1)); - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions opts; - opts.exhaustive = true; - opts.maxTableLength = 0; - - CHECK_EQ("{any}", toString(requireType("tab"), opts)); -} - -TEST_CASE_FIXTURE(Fixture, "too_many_return_values") -{ - CheckResult result = check(R"( - --!strict - - function f() - return 55 - end - - local a, b = f() - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Result); - CHECK_EQ(acm->expected, 1); - CHECK_EQ(acm->actual, 2); -} - -TEST_CASE_FIXTURE(Fixture, "ignored_return_values") -{ - CheckResult result = check(R"( - --!strict - - function f() - return 55, "" - end - - local a = f() - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); -} - -TEST_CASE_FIXTURE(Fixture, "function_does_not_return_enough_values") -{ - CheckResult result = check(R"( - --!strict - - function f(): (number, string) - return 55 - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CountMismatch* acm = get(result.errors[0]); - REQUIRE(acm); - CHECK_EQ(acm->context, CountMismatch::Return); - CHECK_EQ(acm->expected, 2); - CHECK_EQ(acm->actual, 1); -} - -TEST_CASE_FIXTURE(Fixture, "typecheck_unary_minus") -{ - CheckResult result = check(R"( - --!strict - local foo = { - value = 10 - } - local mt = {} - setmetatable(foo, mt) - - mt.__unm = function(val: typeof(foo)): string - return val.value .. "test" - end - - local a = -foo - - local b = 1+-1 - - local bar = { - value = 10 - } - local c = -bar -- disallowed - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - CHECK_EQ("string", toString(requireType("a"))); - CHECK_EQ("number", toString(requireType("b"))); - - GenericError* gen = get(result.errors[0]); - REQUIRE_EQ(gen->message, "Unary operator '-' not supported by type 'bar'"); -} - -TEST_CASE_FIXTURE(Fixture, "unary_not_is_boolean") -{ - CheckResult result = check(R"( - local b = not "string" - local c = not (math.random() > 0.5 and "string" or 7) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("boolean", toString(requireType("b"))); - REQUIRE_EQ("boolean", toString(requireType("c"))); -} - -TEST_CASE_FIXTURE(Fixture, "disallow_string_and_types_without_metatables_from_arithmetic_binary_ops") -{ - CheckResult result = check(R"( - --!strict - local a = "1.24" + 123 -- not allowed - - local foo = { - value = 10 - } - - local b = foo + 1 -- not allowed - - local bar = { - value = 1 - } - - local mt = {} - - setmetatable(bar, mt) - - mt.__add = function(a: typeof(bar), b: number): number - return a.value + b - end - - local c = bar + 1 -- allowed - - local d = bar + foo -- not allowed - )"); - - LUAU_REQUIRE_ERROR_COUNT(3, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE_EQ(*tm->wantedType, *typeChecker.numberType); - REQUIRE_EQ(*tm->givenType, *typeChecker.stringType); - - TypeMismatch* tm2 = get(result.errors[2]); - CHECK_EQ(*tm2->wantedType, *typeChecker.numberType); - CHECK_EQ(*tm2->givenType, *requireType("foo")); - - GenericError* gen2 = get(result.errors[1]); - REQUIRE_EQ(gen2->message, "Binary operator '+' not supported by types 'foo' and 'number'"); -} - -// CLI-29033 -TEST_CASE_FIXTURE(Fixture, "unknown_type_in_comparison") -{ - CheckResult result = check(R"( - function merge(lower, greater) - if lower.y == greater.y then - end - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "concat_op_on_free_lhs_and_string_rhs") -{ - CheckResult result = check(R"( - local function f(x) - return x .. "y" - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - REQUIRE(get(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "concat_op_on_string_lhs_and_free_rhs") -{ - CheckResult result = check(R"( - local function f(x) - return "foo" .. x - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("(string) -> string", toString(requireType("f"))); -} - -TEST_CASE_FIXTURE(Fixture, "strict_binary_op_where_lhs_unknown") -{ - std::vector ops = {"+", "-", "*", "/", "%", "^", ".."}; - - std::string src = R"( - function foo(a, b) - )"; - - for (const auto& op : ops) - src += "local _ = a " + op + "b\n"; - - src += "end"; - - CheckResult result = check(src); - LUAU_REQUIRE_ERROR_COUNT(ops.size(), result); - - CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'a'", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "function_cast_error_uses_correct_language") -{ - CheckResult result = check(R"( - function foo(a, b): number - return 0 - end - - local a: (string)->number = foo - local b: (number, number)->(number, number) = foo - - local c: (string, number)->number = foo -- no error - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - - auto tm1 = get(result.errors[0]); - REQUIRE(tm1); - - CHECK_EQ("(string) -> number", toString(tm1->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm1->givenType)); - - auto tm2 = get(result.errors[1]); - REQUIRE(tm2); - - CHECK_EQ("(number, number) -> (number, number)", toString(tm2->wantedType)); - CHECK_EQ("(string, *unknown*) -> number", toString(tm2->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "setmetatable_cant_be_used_to_mutate_global_types") -{ - { - Fixture fix; - - // inherit env from parent fixture checker - fix.typeChecker.globalScope = typeChecker.globalScope; - - fix.check(R"( ---!nonstrict -type MT = typeof(setmetatable) -function wtf(arg: {MT}): typeof(table) - arg = wtf(arg) -end -)"); - } - - // validate sharedEnv post-typecheck; valuable for debugging some typeck crashes but slows fuzzing down - // note: it's important for typeck to be destroyed at this point! - { - for (auto& p : typeChecker.globalScope->bindings) - { - toString(p.second.typeId); // toString walks the entire type, making sure ASAN catches access to destroyed type arenas - } - } -} - -TEST_CASE_FIXTURE(Fixture, "evil_table_unification") -{ - // this code re-infers the type of _ while processing fields of _, which can cause use-after-free - check(R"( ---!nonstrict -_ = ... -_:table(_,string)[_:gsub(_,...,n0)],_,_:gsub(_,string)[""],_:split(_,...,table)._,n0 = nil -do end -)"); -} - -TEST_CASE_FIXTURE(Fixture, "overload_is_not_a_function") -{ - check(R"( ---!nonstrict -function _(...):((typeof(not _))&(typeof(not _)))&((typeof(not _))&(typeof(not _))) -_(...)(setfenv,_,not _,"")[_] = nil -end -do end -_(...)(...,setfenv,_):_G() -)"); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") -{ - // this has a risk of creating cyclic type packs, causing infinite loops / OOMs - check(R"( ---!nonstrict -_ += _(_,...) -repeat -_ += _(...) -until ... + _ -)"); - - check(R"( ---!nonstrict -_ += _(_(...,...),_(...)) -repeat -until _ -)"); -} - -TEST_CASE_FIXTURE(Fixture, "cyclic_follow") -{ - check(R"( ---!nonstrict -l0,table,_,_,_ = ... -_,_,_,_.time(...)._.n0,l0,_ = function(l0) -end,_.__index,(_),_.time(_.n0 or _,...) -for l0=...,_,"" do -end -_ += not _ -do end -)"); - - check(R"( ---!nonstrict -n13,_,table,_,l0,_,_ = ... -_,n0[(_)],_,_._(...)._.n39,l0,_._ = function(l84,...) -end,_.__index,"",_,l0._(nil) -for l0=...,table.n5,_ do -end -_:_(...).n1 /= _ -do -_(_ + _) -do end -end -)"); -} - -TEST_CASE_FIXTURE(Fixture, "and_binexps_dont_unify") -{ - CheckResult result = check(R"( - --!strict - local t = {} - while true and t[1] do - print(t[1].test) - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -struct FindFreeTypeVars -{ - bool foundOne = false; - - template - void cycle(ID) - { - } - - template - bool operator()(ID, T) - { - return !foundOne; - } - - template - bool operator()(ID, Unifiable::Free) - { - foundOne = true; - return false; - } -}; - -TEST_CASE_FIXTURE(Fixture, "dont_crash_when_setmetatable_does_not_produce_a_metatabletypevar") -{ - CheckResult result = check("local x = setmetatable({})"); - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "for_in_loop_where_iteratee_is_free") -{ - // This code doesn't pass typechecking. We just care that it doesn't crash. - (void)check(R"( - --!nonstrict - function _:_(...) - end - - repeat - if _ then - else - _ = ... - end - until _ - - for _ in _() do - end - )"); -} - -TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") -{ - CheckResult result = check(R"( - local x = - local a = 7 - )"); - LUAU_REQUIRE_ERRORS(result); - - TypeId aType = requireType("a"); - CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::Number); -} - -// Check that type checker knows about error expressions -TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") -{ - CheckResult result = check("function +() local _ = true end"); - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error") -{ - { - CheckResult result = check(R"( - --!strict - local t = { x = 10, y = 20 } - return t. - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - } - - { - CheckResult result = check(R"( - --!strict - export type = number - export type = string - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - } - - { - CheckResult result = check(R"( - --!strict - function string.() end - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - } - - { - CheckResult result = check(R"( - --!strict - local function () end - local function () end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - } - - { - CheckResult result = check(R"( - --!strict - local dm = {} - function dm.() end - function dm.() end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - } -} - -TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators") -{ - CheckResult result = check(R"( - local a: boolean = true - local b: boolean = false - local foo = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("Type 'boolean' cannot be compared with relational operator <", ge->message); -} - -TEST_CASE_FIXTURE(Fixture, "error_on_invalid_operand_types_to_relational_operators2") -{ - CheckResult result = check(R"( - local a: number | string = "" - local b: number | string = 1 - local foo = a < b - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - GenericError* ge = get(result.errors[0]); - REQUIRE(ge); - CHECK_EQ("Type 'number | string' cannot be compared with relational operator <", ge->message); -} - -TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") -{ - CheckResult result = check(R"( - local foo: any - - print(foo[(true).x]) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - UnknownProperty* up = get(result.errors[0]); // Should probably be NotATable - REQUIRE(up); - CHECK_EQ("boolean", toString(up->table)); - CHECK_EQ("x", up->key); -} - -TEST_CASE_FIXTURE(Fixture, "unreachable_code_after_infinite_loop") -{ - { - CheckResult result = check(R"( - function unreachablecodepath(a): number - while true do - if a then return 10 end - end - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } - - { - CheckResult result = check(R"( - function reachablecodepath(a): number - while true do - if a then break end - return 10 - end - - print("x") -- correct error - end - reachablecodepath(4) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK(get(result.errors[0])); - } - - { - CheckResult result = check(R"( - function unreachablecodepath(a): number - repeat - if a then return 10 end - until false - - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } - - { - CheckResult result = check(R"( - function reachablecodepath(a, b): number - repeat - if a then break end - - if b then return 10 end - until false - - print("x") -- correct error - end - reachablecodepath(4) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK(get(result.errors[0])); - } - - { - CheckResult result = check(R"( - function unreachablecodepath(a: number?): number - repeat - return 10 - until a ~= nil - - -- unreachable - end - unreachablecodepath(4) - )"); - - LUAU_REQUIRE_ERROR_COUNT(0, result); - } -} - -TEST_CASE_FIXTURE(Fixture, "cli_38355_recursive_union") -{ - CheckResult result = check(R"( - --!strict - local _ - _ += _ and _ or _ and _ or _ and _ - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type contains a self-recursive construct that cannot be resolved", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") -{ - CheckResult result = check(R"( - --!strict - local a: number | (string | boolean) | nil - local b: number = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "UnknownGlobalCompoundAssign") -{ - // In non-strict mode, global definition is still allowed - { - CheckResult result = check(R"( - --!nonstrict - a = a + 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } - - // In strict mode we no longer generate two errors from lhs - { - CheckResult result = check(R"( - --!strict - a += 1 - print(a) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } - - // In non-strict mode, compound assignment is not a definition, it's a modification - { - CheckResult result = check(R"( - --!nonstrict - a += 1 - print(a) - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); - CHECK_EQ(toString(result.errors[0]), "Unknown global 'a'"); - } -} - -TEST_CASE_FIXTURE(Fixture, "loop_typecheck_crash_on_empty_optional") -{ - CheckResult result = check(R"( - local t = {} - for _ in t do - for _ in assert(missing()) do - end - end - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") -{ - CheckResult result = check(R"( - local x: {number|number} = {1, 2, 3} - local y = x[1] - x[2] - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "metatable_of_any_can_be_a_table") -{ - CheckResult result = check(R"( ---!strict -local T: any -T = {} -T.__index = T -function T.new(...) - local self = {} - setmetatable(self, T) - self:construct(...) - return self -end -function T:construct(index) -end -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") -{ - CheckResult result = check(R"( - foo - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstExprError") -{ - CheckResult result = check(R"( - local a = foo: - )"); - - LUAU_REQUIRE_ERROR_COUNT(2, result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_ice_on_astexprerror") -{ - CheckResult result = check(R"( - local foo = -; - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); -} - -TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator") -{ - CheckResult result = check(R"( ---!strict -local a: number? = nil -local b: number = a or 1 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "strip_nil_from_lhs_or_operator2") -{ - CheckResult result = check(R"( ---!nonstrict -local a: number? = nil -local b: number = a or 1 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "dont_strip_nil_from_rhs_or_operator") -{ - CheckResult result = check(R"( ---!strict -local a: number? = nil -local b: number = 1 or a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - TypeMismatch* tm = get(result.errors[0]); - REQUIRE(tm); - CHECK_EQ(typeChecker.numberType, tm->wantedType); - CHECK_EQ("number?", toString(tm->givenType)); -} - -TEST_CASE_FIXTURE(Fixture, "no_lossy_function_type") -{ - CheckResult result = check(R"( - --!strict - local tbl = {} - function tbl:abc(a: number, b: number) - return a - end - tbl:abc(1, 2) -- Line 6 - -- | Column 14 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - TypeId type = requireTypeAtPosition(Position(6, 14)); - CHECK_EQ("(tbl, number, number) -> number", toString(type)); - auto ftv = get(type); - REQUIRE(ftv); - CHECK(ftv->hasSelf); -} - -TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") -{ - CheckResult result = check(R"( - --!strict - function Funky() - local a: number = foo - end - - local foo: string = 'hello' + print(a) -- oops! )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - auto e = result.errors.front(); - REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); + UnknownSymbol* us = get(result.errors[0]); + REQUIRE(us); + CHECK_EQ(us->name, "a"); } -TEST_CASE_FIXTURE(Fixture, "operator_eq_verifies_types_do_intersect") +TEST_CASE_FIXTURE(Fixture, "checking_should_not_ice") { - CheckResult result = check(R"( - type Array = { [number]: T } - type Fiber = { id: number } - type null = {} - - local fiberStack: Array = {} - local index = 0 - - local function f(fiber: Fiber) - local a = fiber ~= fiberStack[index] - local b = fiberStack[index] ~= fiber - end - - return f - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} + CHECK_NOTHROW(check(R"( + --!nonstrict + f,g = ... + f(g(...))[...] = nil + f,xpcall = ... + local value = g(...)(g(...)) + )")); -TEST_CASE_FIXTURE(Fixture, "general_require_call_expression") -{ - fileResolver.source["game/A"] = R"( ---!strict -return { def = 4 } - )"; - - fileResolver.source["game/B"] = R"( ---!strict -local tbl = { abc = require(game.A) } -local a : string = "" -a = tbl.abc.def - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'string'", toString(result.errors[0])); + CHECK_EQ("any", toString(requireType("value"))); } -TEST_CASE_FIXTURE(Fixture, "general_require_type_mismatch") -{ - fileResolver.source["game/A"] = R"( -return { def = 4 } - )"; +// TEST_CASE_FIXTURE(Fixture, "infer_method_signature_of_argument") +// { +// CheckResult result = check(R"( +// function f(a) +// if a.cond then +// return a.method() +// end +// end +// )"); - fileResolver.source["game/B"] = R"( -local tbl: string = require(game.A) - )"; +// LUAU_REQUIRE_NO_ERRORS(result); - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type '{| def: number |}' could not be converted into 'string'", toString(result.errors[0])); -} +// CHECK_EQ("A", toString(requireType("f"))); +// } -TEST_CASE_FIXTURE(Fixture, "nonstrict_self_mismatch_tail") +TEST_CASE_FIXTURE(Fixture, "cyclic_follow") { - CheckResult result = check(R"( + check(R"( --!nonstrict -local f = {} -function f:foo(a: number, b: number) end - -function bar(...) - f.foo(f, 1, ...) +l0,table,_,_,_ = ... +_,_,_,_.time(...)._.n0,l0,_ = function(l0) +end,_.__index,(_),_.time(_.n0 or _,...) +for l0=...,_,"" do end - -bar(2) -)"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "typeof_unresolved_function") -{ - CheckResult result = check(R"( -local function f(a: typeof(f)) end +_ += not _ +do end )"); - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Unknown global 'f'", toString(result.errors[0])); -} -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning") -{ - CheckResult result = check(R"( + check(R"( --!nonstrict -local l0:any,l61:t0 = _,math -while _ do -_() +n13,_,table,_,l0,_,_ = ... +_,n0[(_)],_,_._(...)._.n39,l0,_._ = function(l84,...) +end,_.__index,"",_,l0._(nil) +for l0=...,table.n5,_ do end -function _():t0 +_:_(...).n1 /= _ +do +_(_ + _) +do end end -type t0 = any )"); - - std::optional ty = requireType("math"); - REQUIRE(ty); - - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); } -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_2") +struct FindFreeTypeVars { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - - CheckResult result = check(R"( -type X = T -type K = X -)"); + bool foundOne = false; - LUAU_REQUIRE_NO_ERRORS(result); + template + void cycle(ID) + { + } - std::optional ty = requireType("math"); - REQUIRE(ty); + template + bool operator()(ID, T) + { + return !foundOne; + } - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); -} + template + bool operator()(ID, Unifiable::Free) + { + foundOne = true; + return false; + } +}; -TEST_CASE_FIXTURE(Fixture, "instantiate_table_cloning_3") +TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery") { - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; - CheckResult result = check(R"( -type X = T -local a = {} -a.x = 4 -local b: X -a.y = 5 -local c: X -c = b -)"); - - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional ty = requireType("a"); - REQUIRE(ty); + local x = + local a = 7 + )"); + LUAU_REQUIRE_ERRORS(result); - const TableTypeVar* ttv = get(*ty); - REQUIRE(ttv); - CHECK(ttv->instantiatedTypeParams.empty()); + TypeId aType = requireType("a"); + CHECK_EQ(getPrimitiveType(aType), PrimitiveTypeVar::Number); } -TEST_CASE_FIXTURE(Fixture, "bound_free_table_export_is_ok") +// Check that type checker knows about error expressions +TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_assert") { - CheckResult result = check(R"( -local n = {} -function n:Clone() end + CheckResult result = check("function +() local _ = true end"); + LUAU_REQUIRE_ERRORS(result); +} -local m = {} +TEST_CASE_FIXTURE(Fixture, "tc_after_error_recovery_no_replacement_name_in_error") +{ + { + CheckResult result = check(R"( + --!strict + local t = { x = 10, y = 20 } + return t. + )"); -function m.a(x) - x:Clone() -end + LUAU_REQUIRE_ERROR_COUNT(1, result); + } -function m.b() - m.a(n) -end + { + CheckResult result = check(R"( + --!strict + export type = number + export type = string + )"); -return m -)"); - LUAU_REQUIRE_NO_ERRORS(result); -} + LUAU_REQUIRE_ERROR_COUNT(2, result); + } -TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") -{ - ScopedFastFlag sff{"LuauOnlyMutateInstantiatedTables", true}; + { + CheckResult result = check(R"( + --!strict + function string.() end + )"); - // Mutability in type function application right now can create strange recursive types - CheckResult result = check(R"( -type Table = { a: number } -type Self = T -local a: Self
- )"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + } - LUAU_REQUIRE_NO_ERRORS(result); - CHECK_EQ(toString(requireType("a")), "Table"); -} + { + CheckResult result = check(R"( + --!strict + local function () end + local function () end + )"); -TEST_CASE_FIXTURE(Fixture, "no_persistent_typelevel_change") -{ - TypeId mathTy = requireType(typeChecker.globalScope, "math"); - REQUIRE(mathTy); - TableTypeVar* ttv = getMutable(mathTy); - REQUIRE(ttv); - const FunctionTypeVar* ftv = get(ttv->props["frexp"].type); - REQUIRE(ftv); - auto original = ftv->level; + LUAU_REQUIRE_ERROR_COUNT(2, result); + } - CheckResult result = check("local a = math.frexp"); + { + CheckResult result = check(R"( + --!strict + local dm = {} + function dm.() end + function dm.() end + )"); - LUAU_REQUIRE_NO_ERRORS(result); - CHECK(ftv->level.level == original.level); - CHECK(ftv->level.subLevel == original.subLevel); + LUAU_REQUIRE_ERROR_COUNT(2, result); + } } -TEST_CASE_FIXTURE(Fixture, "table_indexing_error_location") +TEST_CASE_FIXTURE(Fixture, "index_expr_should_be_checked") { CheckResult result = check(R"( -local foo = {42} -local bar: number? -local baz = foo[bar] + local foo: any + + print(foo[(true).x]) )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(result.errors[0].location, Location{Position{3, 16}, Position{3, 19}}); + UnknownProperty* up = get(result.errors[0]); // Should probably be NotATable + REQUIRE(up); + CHECK_EQ("boolean", toString(up->table)); + CHECK_EQ("x", up->key); } -TEST_CASE_FIXTURE(Fixture, "table_simple_call") +TEST_CASE_FIXTURE(Fixture, "stringify_nested_unions_with_optionals") { CheckResult result = check(R"( -local a = setmetatable({ x = 2 }, { - __call = function(self) - return (self.x :: number) * 2 -- should work without annotation in the future - end -}) -local b = a() -local c = a(2) -- too many arguments + --!strict + local a: number | (string | boolean) | nil + local b: number = a )"); LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Argument count mismatch. Function expects 1 argument, but 2 are specified", toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "custom_require_global") -{ - CheckResult result = check(R"( ---!nonstrict -require = function(a) end - -local crash = require(game.A) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); + TypeMismatch* tm = get(result.errors[0]); + REQUIRE(tm); + CHECK_EQ(typeChecker.numberType, tm->wantedType); + CHECK_EQ("(boolean | number | string)?", toString(tm->givenType)); } -TEST_CASE_FIXTURE(Fixture, "operator_eq_operands_are_not_subtypes_of_each_other_but_has_overlap") +TEST_CASE_FIXTURE(Fixture, "cli_39932_use_unifier_in_ensure_methods") { - ScopedFastFlag sff1{"LuauEqConstraint", true}; - CheckResult result = check(R"( - local function f(a: string | number, b: boolean | number) - return a == b - end + local x: {number|number} = {1, 2, 3} + local y = x[1] - x[2] )"); - // This doesn't produce any errors but for the wrong reasons. - // This unit test serves as a reminder to not try and unify the operands on `==`/`~=`. LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "access_index_metamethod_that_returns_variadic") +TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstStatError") { CheckResult result = check(R"( - type Foo = {x: string} - local t = {} - setmetatable(t, { - __index = function(x: string): ...Foo - return {x = x} - end - }) - - local foo = t.bar + foo )"); - LUAU_REQUIRE_NO_ERRORS(result); - - ToStringOptions o; - o.exhaustive = true; - CHECK_EQ("{| x: string |}", toString(requireType("foo"), o)); + LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") +TEST_CASE_FIXTURE(Fixture, "dont_report_type_errors_within_an_AstExprError") { CheckResult result = check(R"( - type ( ... ) ( ) ; - ( ... ) ( - - ... ) ( - ... ) - type = ( ... ) ; - ( ... ) ( ) ( ... ) ; - ( ... ) "" + local a = foo: )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERROR_COUNT(2, result); } -TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") +TEST_CASE_FIXTURE(Fixture, "dont_ice_on_astexprerror") { CheckResult result = check(R"( - function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) - xpcall(_,_,_) - _(_,_,_) - end + local foo = -; )"); - CHECK_LE(0, result.errors.size()); + LUAU_REQUIRE_ERROR_COUNT(1, result); } -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_quantifying") +TEST_CASE_FIXTURE(Fixture, "luau_resolves_symbols_the_same_way_lua_does") { CheckResult result = check(R"( - function _(l0:t0): (any, ()->()) + --!strict + function Funky() + local a: number = foo end - type t0 = t0 | {} + local foo: string = 'hello' )"); - CHECK_LE(0, result.errors.size()); - - std::optional t0 = getMainModule()->getModuleScope()->lookupType("t0"); - REQUIRE(t0); - CHECK_EQ("*unknown*", toString(t0->type)); + LUAU_REQUIRE_ERROR_COUNT(1, result); - auto it = std::find_if(result.errors.begin(), result.errors.end(), [](TypeError& err) { - return get(err); - }); - CHECK(it != result.errors.end()); + auto e = result.errors.front(); + REQUIRE_MESSAGE(get(e) != nullptr, "Expected UnknownSymbol, but got " << e); } TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_isoptional") @@ -4316,360 +713,21 @@ TEST_CASE_FIXTURE(Fixture, "no_infinite_loop_when_trying_to_unify_uh_this") CHECK_LE(0, result.errors.size()); } -TEST_CASE_FIXTURE(Fixture, "no_stack_overflow_from_flattenintersection") -{ - CheckResult result = check(R"( - local l0,l0 - repeat - type t0 = ((any)|((any)&((any)|((any)&((any)|(any))))))&(t0) - function _(l0):(t0)&(t0) - while nil do - end - end - until _(_)(_)._ - )"); - - CHECK_LE(0, result.errors.size()); -} - TEST_CASE_FIXTURE(Fixture, "no_heap_use_after_free_error") -{ - CheckResult result = check(R"( - --!nonstrict - _ += _:n0(xpcall,_) - local l0 - do end - while _ do - function _:_() - _ += _(_._(_:n0(xpcall,_))) - end - end - )"); - - CHECK_LE(0, result.errors.size()); -} - -TEST_CASE_FIXTURE(Fixture, "dont_invalidate_the_properties_iterator_of_free_table_when_rolled_back") -{ - fileResolver.source["Module/Backend/Types"] = R"( - export type Fiber = { - return_: Fiber? - } - return {} - )"; - - fileResolver.source["Module/Backend"] = R"( - local Types = require(script.Types) - type Fiber = Types.Fiber - type ReactRenderer = { findFiberByHostInstance: () -> Fiber? } - - local function attach(renderer): () - local function getPrimaryFiber(fiber) - local alternate = fiber.alternate - return fiber - end - - local function getFiberIDForNative() - local fiber = renderer.findFiberByHostInstance() - fiber = fiber.return_ - return getPrimaryFiber(fiber) - end - end - - function culprit(renderer: ReactRenderer): () - attach(renderer) - end - - return culprit - )"; - - CheckResult result = frontend.check("Module/Backend"); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_ok") -{ - CheckResult result = check(R"( - type Tree = { data: T, children: {Tree} } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "recursive_types_restriction_not_ok") -{ - ScopedFastFlag sff{"LuauRecursiveTypeParameterRestriction", true}; - - CheckResult result = check(R"( - -- this would be an infinite type if we allowed it - type Tree = { data: T, children: {Tree<{T}>} } - )"); - - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "record_matching_overload") -{ - CheckResult result = check(R"( - type Overload = ((string) -> string) & ((number) -> number) - local abc: Overload - abc(1) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // AstExprCall is the node that has the overload stored on it. - // findTypeAtPosition will look at the AstExprLocal, but this is not what - // we want to look at. - std::vector ancestry = findAstAncestryOfPosition(*getMainSourceModule(), Position(3, 10)); - REQUIRE_GE(ancestry.size(), 2); - AstExpr* parentExpr = ancestry[ancestry.size() - 2]->asExpr(); - REQUIRE(bool(parentExpr)); - REQUIRE(parentExpr->is()); - - ModulePtr module = getMainModule(); - auto it = module->astOverloadResolvedTypes.find(parentExpr); - REQUIRE(it); - CHECK_EQ(toString(*it), "(number) -> number"); -} - -TEST_CASE_FIXTURE(Fixture, "return_type_by_overload") -{ - ScopedFastFlag sff{"LuauErrorRecoveryType", true}; - - CheckResult result = check(R"( - type Overload = ((string) -> string) & ((number, number) -> number) - local abc: Overload - local x = abc(true) - local y = abc(true,true) - local z = abc(true,true,true) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("string", toString(requireType("x"))); - CHECK_EQ("number", toString(requireType("y"))); - // Should this be string|number? - CHECK_EQ("string", toString(requireType("z"))); -} - -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments") -{ - // Simple direct arg to arg propagation - CheckResult result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // An optional function is accepted, but since we already provide a function, nil can be ignored - result = check(R"( -type Table = { x: number, y: number } -local function f(a: ((Table) -> number)?) if a then return a({x = 1, y = 2}) else return 0 end end -f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Make sure self calls match correct index - result = check(R"( -type Table = { x: number, y: number } -local x = {} -x.b = {x = 1, y = 2} -function x:f(a: (Table) -> number) return a(self.b) end -x:f(function(a) return a.x + a.y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Mix inferred and explicit argument types - result = check(R"( -function f(a: (a: number, b: number, c: boolean) -> number) return a(1, 2, true) end -f(function(a: number, b, c) return c and a + b or b - a end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Anonymous function has a variadic pack - result = check(R"( -type Table = { x: number, y: number } -local function f(a: (Table) -> number) return a({x = 1, y = 2}) end -f(function(...) return select(1, ...).z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Can't accept more arguments than provided - result = check(R"( -function f(a: (a: number, b: number) -> number) return a(1, 2) end -f(function(a, b, c, ...) return a + b end) - )"); - - LUAU_REQUIRE_ERRORS(result); - - CHECK_EQ(R"(Type '(number, number, a) -> number' could not be converted into '(number, number) -> number' -caused by: - Argument count mismatch. Function expects 3 arguments, but only 2 are specified)", - toString(result.errors[0])); - - // Infer from variadic packs into elements - result = check(R"( -function f(a: (...number) -> number) return a(1, 2) end -f(function(a, b) return a + b end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - // Infer from variadic packs into variadic packs - result = check(R"( -type Table = { x: number, y: number } -function f(a: (...Table) -> number) return a({x = 1, y = 2}, {x = 3, y = 4}) end -f(function(a, ...) local b = ... return b.z end) - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Key 'z' not found in table 'Table'", toString(result.errors[0])); - - // Return type inference - result = check(R"( -type Table = { x: number, y: number } -function f(a: (number) -> Table) return a(4) end -f(function(x) return x * 2 end) - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Type 'number' could not be converted into 'Table'", toString(result.errors[0])); - - // Return type doesn't inference 'nil' - result = check(R"( -function f(a: (number) -> nil) return a(4) end -f(function(x) print(x) end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument") -{ - ScopedFastFlag sff{"LuauUnsealedTableLiteral", true}; - - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end -return sum(2, 3, function(a, b) return a + b end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - result = check(R"( -local function map(arr: {a}, f: (a) -> b) local r = {} for i,v in ipairs(arr) do table.insert(r, f(v)) end return r end -local a = {1, 2, 3} -local r = map(a, function(a) return a + a > 100 end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{boolean}", toString(requireType("r"))); - - check(R"( -local function foldl(arr: {a}, init: b, f: (b, a) -> b) local r = init for i,v in ipairs(arr) do r = f(r, v) end return r end -local a = {1, 2, 3} -local r = foldl(a, {s=0,c=0}, function(a, b) return {s = a.s + b, c = a.c + 1} end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - REQUIRE_EQ("{ c: number, s: number }", toString(requireType("r"))); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_function_function_argument_overloaded") -{ - CheckResult result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - -local g12: typeof(g1) & typeof(g2) - -g12(1, function(x) return x + x end) -g12(1, 2, function(x, y) return x + y end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - result = check(R"( -local function g1(a: T, f: (T) -> T) return f(a) end -local function g2(a: T, b: T, f: (T, T) -> T) return f(a, b) end - -local g12: typeof(g1) & typeof(g2) - -g12({x=1}, function(x) return {x=-x.x} end) -g12({x=1}, {x=2}, function(x, y) return {x=x.x + y.x} end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_generic_lib_function_function_argument") -{ - CheckResult result = check(R"( -local a = {{x=4}, {x=7}, {x=1}} -table.sort(a, function(x, y) return x.x < y.x end) - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "infer_anonymous_function_arguments_outside_call") -{ - CheckResult result = check(R"( -type Table = { x: number, y: number } -local f: (Table) -> number = function(t) return t.x + t.y end - -type TableWithFunc = { x: number, y: number, f: (number, number) -> number } -local a: TableWithFunc = { x = 3, y = 4, f = function(a, b) return a + b end } - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "do_not_infer_generic_functions") -{ - CheckResult result = check(R"( -local function sum(x: a, y: a, f: (a, a) -> a) return f(x, y) end - -local function sumrec(f: typeof(sum)) - return sum(2, 3, function(a, b) return a + b end) -end - -local b = sumrec(sum) -- ok -local c = sumrec(function(x, y, f) return f(x, y) end) -- type binders are not inferred - )"); - - LUAU_REQUIRE_ERRORS(result); - CHECK_EQ("Type '(a, b, (a, b) -> (c...)) -> (c...)' could not be converted into '(a, a, (a, a) -> a) -> a'; different number of generic type " - "parameters", - toString(result.errors[0])); -} - -TEST_CASE_FIXTURE(Fixture, "infer_return_value_type") -{ - CheckResult result = check(R"( -local function f(): {string|number} - return {1, "b", 3} -end - -local function g(): (number, {string|number}) - return 4, {1, "b", 3} -end - -local function h(): ...{string|number} - return {4}, {1, "b", 3}, {"s"} -end - -local function i(): ...{string|number} - return {1, "b", 3}, h() -end +{ + CheckResult result = check(R"( + --!nonstrict + _ += _:n0(xpcall,_) + local l0 + do end + while _ do + function _:_() + _ += _(_._(_:n0(xpcall,_))) + end + end )"); - LUAU_REQUIRE_NO_ERRORS(result); + CHECK_LE(0, result.errors.size()); } TEST_CASE_FIXTURE(Fixture, "infer_type_assertion_value_type") @@ -4719,56 +777,6 @@ f(((function(a, b) return a + b end))) LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "refine_and_or") -{ - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t and t.x or 5 - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ("number", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "checked_prop_too_early") -{ - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t.x and t or 5 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); - CHECK_EQ("number | {| x: number? |}", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "accidentally_checked_prop_in_opposite_branch") -{ - CheckResult result = check(R"( - local t: {x: number?}? = {x = nil} - local u = t and t.x == 5 or t.x == 31337 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ("Value of type '{| x: number? |}?' could be nil", toString(result.errors[0])); - CHECK_EQ("boolean", toString(requireType("u"))); -} - -TEST_CASE_FIXTURE(Fixture, "substitution_with_bound_table") -{ - CheckResult result = check(R"( -type A = { x: number } -local a: A = { x = 1 } -local b = a -type B = typeof(b) -type X = T -local c: X - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "tc_if_else_expressions1") { CheckResult result = check(R"(local a = if true then "true" else "false")"); @@ -4830,40 +838,6 @@ end LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "type_error_addition") -{ - CheckResult result = check(R"( ---!strict -local foo = makesandwich() -local bar = foo.nutrition + 100 - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - - // We should definitely get this error - CHECK_EQ("Unknown global 'makesandwich'", toString(result.errors[0])); - // We get this error if makesandwich() returns a free type - // CHECK_EQ("Unknown type used in + operation; consider adding a type annotation to 'foo'", toString(result.errors[1])); -} - -TEST_CASE_FIXTURE(Fixture, "require_failed_module") -{ - fileResolver.source["game/A"] = R"( -return unfortunately() - )"; - - CheckResult aResult = frontend.check("game/A"); - LUAU_REQUIRE_ERRORS(aResult); - - CheckResult result = check(R"( -local ModuleA = require(game.A) - )"); - LUAU_REQUIRE_NO_ERRORS(result); - - std::optional oty = requireType("ModuleA"); - CHECK_EQ("*unknown*", toString(*oty)); -} - /* * If it wasn't instantly obvious, we have the fuzzer to thank for this gem of a test. * @@ -4921,184 +895,6 @@ TEST_CASE_FIXTURE(Fixture, "fuzzer_found_this") )"); } -/* - * We had an issue where part of the type of pairs() was an unsealed table. - * This test depends on FFlagDebugLuauFreezeArena to trigger it. - */ -TEST_CASE_FIXTURE(Fixture, "pairs_parameters_are_not_unsealed_tables") -{ - check(R"( - function _(l0:{n0:any}) - _ = pairs - end - )"); -} - -TEST_CASE_FIXTURE(Fixture, "inferred_methods_of_free_tables_have_the_same_level_as_the_enclosing_table") -{ - check(R"( - function Base64FileReader(data) - local reader = {} - local index: number - - function reader:PeekByte() - return data:byte(index) - end - - function reader:Byte() - return data:byte(index - 1) - end - - return reader - end - - Base64FileReader() - - function ReadMidiEvents(data) - - local reader = Base64FileReader(data) - - while reader:HasMore() do - (reader:Byte() % 128) - end - end - )"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg_count") -{ - CheckResult result = check(R"( -type A = (number, number) -> string -type B = (number) -> string - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number) -> string' -caused by: - Argument count mismatch. Function expects 2 arguments, but only 1 is specified)"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_arg") -{ - CheckResult result = check(R"( -type A = (number, number) -> string -type B = (number, string) -> string - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, string) -> string' -caused by: - Argument #2 type is not compatible. Type 'string' could not be converted into 'number')"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_count") -{ - CheckResult result = check(R"( -type A = (number, number) -> (number) -type B = (number, number) -> (number, boolean) - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> number' could not be converted into '(number, number) -> (number, boolean)' -caused by: - Function only returns 1 value. 2 are required here)"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret") -{ - CheckResult result = check(R"( -type A = (number, number) -> string -type B = (number, number) -> number - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), R"(Type '(number, number) -> string' could not be converted into '(number, number) -> number' -caused by: - Return type is not compatible. Type 'string' could not be converted into 'number')"); -} - -TEST_CASE_FIXTURE(Fixture, "error_detailed_function_mismatch_ret_mult") -{ - CheckResult result = check(R"( -type A = (number, number) -> (number, string) -type B = (number, number) -> (number, boolean) - -local a: A -local b: B = a - )"); - - LUAU_REQUIRE_ERROR_COUNT(1, result); - CHECK_EQ(toString(result.errors[0]), - R"(Type '(number, number) -> (number, string)' could not be converted into '(number, number) -> (number, boolean)' -caused by: - Return #2 type is not compatible. Type 'string' could not be converted into 'boolean')"); -} - -TEST_CASE_FIXTURE(Fixture, "prop_access_on_any_with_other_options") -{ - CheckResult result = check(R"( - local function f(thing: any | string) - local foo = thing.SomeRandomKey - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "table_function_check_use_after_free") -{ - CheckResult result = check(R"( -local t = {} - -function t.x(value) - for k,v in pairs(t) do end -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "table_oop") -{ - CheckResult result = check(R"( - --!strict -local Class = {} -Class.__index = Class - -type Class = typeof(setmetatable({} :: { x: number }, Class)) - -function Class.new(x: number): Class - return setmetatable({x = x}, Class) -end - -function Class.getx(self: Class) - return self.x -end - -function test() - local c = Class.new(42) - local n = c:getx() - local nn = c.x - - print(string.format("%d %d", n, nn)) -end -)"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - TEST_CASE_FIXTURE(Fixture, "recursive_metatable_crash") { CheckResult result = check(R"( @@ -5182,213 +978,4 @@ TEST_CASE_FIXTURE(Fixture, "cli_50041_committing_txnlog_in_apollo_client_error") LUAU_REQUIRE_NO_ERRORS(result); } -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types") -{ - fileResolver.source["game/A"] = R"( -export type Type = { unrelated: boolean } -return {} - )"; - - fileResolver.source["game/B"] = R"( -local types = require(game.A) -type Type = types.Type -local x: Type = {} -function x:Destroy(): () end - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") -{ - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - - fileResolver.source["game/A"] = R"( -export type Type = { x: { a: number } } -return {} - )"; - - fileResolver.source["game/B"] = R"( -local types = require(game.A) -type Type = types.Type -local x: Type = { x = { a = 2 } } -type Rename = typeof(x.x) - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") -{ - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - - fileResolver.source["game/A"] = R"( -local y = setmetatable({}, {}) -export type Type = { x: typeof(y) } -return { x = y } - )"; - - fileResolver.source["game/B"] = R"( -local types = require(game.A) -type Type = types.Type -local x: Type = types -type Rename = typeof(x.x) - )"; - - CheckResult result = frontend.check("game/B"); - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" then - local x = a:byte() - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 22}))); -} - -TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" or a == "bye" then - local x = a:byte() - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 22}))); -} - -TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" then - local x = #a - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("hi")", toString(requireTypeAtPosition({3, 23}))); -} - -TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") -{ - ScopedFastFlag sff[]{ - {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, - }; - - CheckResult result = check(R"( - local a: string = "hi" - if a == "hi" or a == "bye" then - local x = #a - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); - - CHECK_EQ(R"("bye" | "hi")", toString(requireTypeAtPosition({3, 23}))); -} - -/* - * When we add new properties to an unsealed table, we should do a level check and promote the property type to be at - * the level of the table. - */ -TEST_CASE_FIXTURE(Fixture, "inferred_properties_of_a_table_should_start_with_the_same_TypeLevel_of_that_table") -{ - CheckResult result = check(R"( - --!strict - local T = {} - - local function f(prop) - T[1] = { - prop = prop, - } - end - - local function g() - local l = T[1].prop - end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "global_singleton_types_are_sealed") -{ - CheckResult result = check(R"( -local function f(x: string) - local p = x:split('a') - p = table.pack(table.unpack(p, 1, #p - 1)) - return p -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "function_decl_quantify_right_type") -{ - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; - - fileResolver.source["game/isAMagicMock"] = R"( ---!nonstrict -return function(value) - return false -end - )"; - - CheckResult result = check(R"( ---!nonstrict -local MagicMock = {} -MagicMock.is = require(game.isAMagicMock) - -function MagicMock.is(value) - return false -end - )"); - - LUAU_REQUIRE_NO_ERRORS(result); -} - -TEST_CASE_FIXTURE(Fixture, "function_decl_non_self_sealed_overwrite") -{ - ScopedFastFlag statFunctionSimplify{"LuauStatFunctionSimplify", true}; - - CheckResult result = check(R"( -function string.len(): number - return 1 -end - )"); - - LUAU_REQUIRE_ERRORS(result); -} - TEST_SUITE_END(); diff --git a/tests/TypeInfer.typePacks.cpp b/tests/TypeInfer.typePacks.cpp index fcc21c18e..130f33d70 100644 --- a/tests/TypeInfer.typePacks.cpp +++ b/tests/TypeInfer.typePacks.cpp @@ -895,4 +895,87 @@ caused by: Type 'boolean' could not be converted into 'string')"); } +// TODO: File a Jira about this +/* +TEST_CASE_FIXTURE(Fixture, "unifying_vararg_pack_with_fixed_length_pack_produces_fixed_length_pack") +{ + CheckResult result = check(R"( + function a(x) return 1 end + a(...) + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + REQUIRE(bool(getMainModule()->getModuleScope()->varargPack)); + + TypePackId varargPack = *getMainModule()->getModuleScope()->varargPack; + + auto iter = begin(varargPack); + auto endIter = end(varargPack); + + CHECK(iter != endIter); + ++iter; + CHECK(iter == endIter); + + CHECK(!iter.tail()); +} +*/ + +TEST_CASE_FIXTURE(Fixture, "dont_ice_if_a_TypePack_is_an_error") +{ + CheckResult result = check(R"( + --!strict + function f(s) + print(s) + return f + end + + f("foo")("bar") + )"); +} + +TEST_CASE_FIXTURE(Fixture, "cyclic_type_packs") +{ + // this has a risk of creating cyclic type packs, causing infinite loops / OOMs + check(R"( +--!nonstrict +_ += _(_,...) +repeat +_ += _(...) +until ... + _ +)"); + + check(R"( +--!nonstrict +_ += _(_(...,...),_(...)) +repeat +until _ +)"); +} + +TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks") +{ + CheckResult result = check(R"( + type ( ... ) ( ) ; + ( ... ) ( - - ... ) ( - ... ) + type = ( ... ) ; + ( ... ) ( ) ( ... ) ; + ( ... ) "" + )"); + + CHECK_LE(0, result.errors.size()); +} + +TEST_CASE_FIXTURE(Fixture, "detect_cyclic_typepacks2") +{ + CheckResult result = check(R"( + function _(l0:((typeof((pcall)))|((((t0)->())|(typeof(-67108864)))|(any)))|(any),...):(((typeof(0))|(any))|(any),typeof(-67108864),any) + xpcall(_,_,_) + _(_,_,_) + end + )"); + + CHECK_LE(0, result.errors.size()); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ad4cecd87..ad1e31e5c 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -496,4 +496,20 @@ caused by: None of the union options are compatible. For example: Table type 'a' not compatible with type 'X' because the former is missing field 'x')"); } +// We had a bug where a cyclic union caused a stack overflow. +// ex type U = number | U +TEST_CASE_FIXTURE(Fixture, "dont_allow_cyclic_unions_to_be_inferred") +{ + CheckResult result = check(R"( + --!strict + + function f(a, b) + a:g(b or {}) + a:g(b) + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); +} + TEST_SUITE_END(); diff --git a/tests/conformance/basic.lua b/tests/conformance/basic.lua index 78d900770..f803c3193 100644 --- a/tests/conformance/basic.lua +++ b/tests/conformance/basic.lua @@ -833,6 +833,17 @@ assert((function() return sum end)() == 105) +-- shrinking array part +assert((function() + local t = table.create(100, 42) + for i=1,90 do t[i] = nil end + t[101] = 42 + local sum = 0 + for _,v in ipairs(t) do sum += v end + for _,v in pairs(t) do sum += v end + return sum +end)() == 462) + -- upvalues: recursive capture assert((function() local function fact(n) return n < 1 and 1 or n * fact(n-1) end return fact(5) end)() == 120) @@ -881,6 +892,14 @@ end)() == "6,8,10") -- typeof == type in absence of custom userdata assert(concat(typeof(5), typeof(nil), typeof({}), typeof(newproxy())) == "number,nil,table,userdata") +-- type/typeof/newproxy interaction with metatables: __type doesn't work intentionally to avoid spoofing +assert((function() + local ud = newproxy(true) + getmetatable(ud).__type = "number" + + return concat(type(ud),typeof(ud)) +end)() == "userdata,userdata") + testgetfenv() -- DONT MOVE THIS LINE return 'OK' diff --git a/tests/conformance/debugger.lua b/tests/conformance/debugger.lua index 6ba99fb9b..ec0b412e0 100644 --- a/tests/conformance/debugger.lua +++ b/tests/conformance/debugger.lua @@ -3,14 +3,14 @@ print "testing debugger" -- note, this file can't run in isolation from C tests local a = 5 -function foo(b) +function foo(b, ...) print("in foo", b) a = 6 end breakpoint(8) -foo(50) +foo(50, 42) breakpoint(16) -- next line print("here") diff --git a/tests/conformance/errors.lua b/tests/conformance/errors.lua index 751188bed..297cf0110 100644 --- a/tests/conformance/errors.lua +++ b/tests/conformance/errors.lua @@ -305,4 +305,6 @@ assert(ecall(function() return "a" + "b" end) == "attempt to perform arithmetic assert(ecall(function() return 1 > nil end) == "attempt to compare nil < number") -- note reversed order (by design) assert(ecall(function() return "a" <= 5 end) == "attempt to compare string <= number") +assert(ecall(function() local t = {} setmetatable(t, { __newindex = function(t,i,v) end }) t[nil] = 2 end) == "table index is nil") + return('OK') diff --git a/tests/conformance/interrupt.lua b/tests/conformance/interrupt.lua new file mode 100644 index 000000000..2b1270991 --- /dev/null +++ b/tests/conformance/interrupt.lua @@ -0,0 +1,11 @@ +-- This file is part of the Luau programming language and is licensed under MIT License; see LICENSE.txt for details +print("testing interrupts") + +function foo() + for i=1,10 do end + return +end + +foo() + +return "OK" diff --git a/tools/gdb-printers.py b/tools/gdb_printers.py similarity index 100% rename from tools/gdb-printers.py rename to tools/gdb_printers.py diff --git a/tools/lldb-formatters.lldb b/tools/lldb-formatters.lldb deleted file mode 100644 index 3868ac20c..000000000 --- a/tools/lldb-formatters.lldb +++ /dev/null @@ -1,2 +0,0 @@ -type synthetic add -x "^Luau::Variant<.+>$" -l LuauVisualize.LuauVariantSyntheticChildrenProvider -type summary add -x "^Luau::Variant<.+>$" -l LuauVisualize.luau_variant_summary diff --git a/tools/lldb_formatters.lldb b/tools/lldb_formatters.lldb new file mode 100644 index 000000000..f6fa6cf5a --- /dev/null +++ b/tools/lldb_formatters.lldb @@ -0,0 +1,2 @@ +type synthetic add -x "^Luau::Variant<.+>$" -l lldb_formatters.LuauVariantSyntheticChildrenProvider +type summary add -x "^Luau::Variant<.+>$" -l lldb_formatters.luau_variant_summary diff --git a/tools/LuauVisualize.py b/tools/lldb_formatters.py similarity index 100% rename from tools/LuauVisualize.py rename to tools/lldb_formatters.py From 57faf7aaf2d6dd20eb69bc4088e61549cfa6552f Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 17 Mar 2022 17:32:02 -0700 Subject: [PATCH 31/32] Lower the stack limit to make tests pass in debug --- tests/TypeInfer.test.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 571d0f8d6..660ddcfcf 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -334,7 +334,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") #if defined(LUAU_ENABLE_ASAN) int limit = 250; #elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; + int limit = 300; #else int limit = 600; #endif From 373da161e915de2aa71ba83fe9baf23b269057f3 Mon Sep 17 00:00:00 2001 From: Arseny Kapoulkine Date: Thu, 24 Mar 2022 14:49:08 -0700 Subject: [PATCH 32/32] Sync to upstream/release/520 --- Analysis/include/Luau/Error.h | 1 + Analysis/include/Luau/ToString.h | 3 +- Analysis/include/Luau/TypePack.h | 6 +- Analysis/include/Luau/TypeVar.h | 3 + Analysis/include/Luau/Unifier.h | 4 +- Analysis/include/Luau/UnifierSharedState.h | 2 + Analysis/src/Error.cpp | 30 ++- Analysis/src/Linter.cpp | 43 +++- Analysis/src/Module.cpp | 31 +-- Analysis/src/ToString.cpp | 68 +++++-- Analysis/src/TypeInfer.cpp | 69 ++++--- Analysis/src/TypePack.cpp | 19 +- Analysis/src/TypeVar.cpp | 18 ++ Analysis/src/Unifier.cpp | 173 +++++++++++----- Ast/src/Parser.cpp | 20 +- VM/src/lapi.cpp | 44 ++-- VM/src/ldo.cpp | 21 +- VM/src/ldo.h | 2 +- VM/src/lgc.cpp | 224 +++++++++++---------- VM/src/lgc.h | 2 +- VM/src/lstate.cpp | 26 +-- VM/src/lstate.h | 44 ++-- VM/src/ltable.cpp | 4 +- tests/Autocomplete.test.cpp | 5 - tests/Conformance.test.cpp | 2 - tests/Fixture.cpp | 17 +- tests/Linter.test.cpp | 87 ++------ tests/ToString.test.cpp | 25 +++ tests/Transpiler.test.cpp | 2 - tests/TypeInfer.builtins.test.cpp | 2 - tests/TypeInfer.functions.test.cpp | 76 +++++++ tests/TypeInfer.modules.test.cpp | 85 +++++++- tests/TypeInfer.operators.test.cpp | 26 +++ tests/TypeInfer.refinements.test.cpp | 11 - tests/TypeInfer.singletons.test.cpp | 118 +---------- tests/TypeInfer.tables.test.cpp | 38 ++++ tests/TypeInfer.test.cpp | 2 +- tests/TypeInfer.unionTypes.test.cpp | 1 + 38 files changed, 798 insertions(+), 556 deletions(-) diff --git a/Analysis/include/Luau/Error.h b/Analysis/include/Luau/Error.h index 72350255e..53b946a06 100644 --- a/Analysis/include/Luau/Error.h +++ b/Analysis/include/Luau/Error.h @@ -96,6 +96,7 @@ struct CountMismatch size_t expected; size_t actual; Context context = Arg; + bool isVariadic = false; bool operator==(const CountMismatch& rhs) const; }; diff --git a/Analysis/include/Luau/ToString.h b/Analysis/include/Luau/ToString.h index a97bf6d6b..49ee82fe3 100644 --- a/Analysis/include/Luau/ToString.h +++ b/Analysis/include/Luau/ToString.h @@ -32,6 +32,7 @@ struct ToStringOptions size_t maxTypeLength = size_t(FInt::LuauTypeMaximumStringifierLength); std::optional nameMap; std::shared_ptr scope; // If present, module names will be added and types that are not available in scope will be marked as 'invalid' + std::vector namedFunctionOverrideArgNames; // If present, named function argument names will be overridden }; struct ToStringResult @@ -65,7 +66,7 @@ inline std::string toString(TypePackId ty) std::string toString(const TypeVar& tv, const ToStringOptions& opts = {}); std::string toString(const TypePackVar& tp, const ToStringOptions& opts = {}); -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts = {}); +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, const ToStringOptions& opts = {}); // It could be useful to see the text representation of a type during a debugging session instead of exploring the content of the class // These functions will dump the type to stdout and can be evaluated in Watch/Immediate windows or as gdb/lldb expression diff --git a/Analysis/include/Luau/TypePack.h b/Analysis/include/Luau/TypePack.h index 946be3561..85fa467f7 100644 --- a/Analysis/include/Luau/TypePack.h +++ b/Analysis/include/Luau/TypePack.h @@ -119,9 +119,9 @@ bool areEqual(SeenSet& seen, const TypePackVar& lhs, const TypePackVar& rhs); TypePackId follow(TypePackId tp); TypePackId follow(TypePackId tp, std::function mapper); -size_t size(TypePackId tp); -bool finite(TypePackId tp); -size_t size(const TypePack& tp); +size_t size(TypePackId tp, TxnLog* log = nullptr); +bool finite(TypePackId tp, TxnLog* log = nullptr); +size_t size(const TypePack& tp, TxnLog* log = nullptr); std::optional first(TypePackId tp); TypePackVar* asMutable(TypePackId tp); diff --git a/Analysis/include/Luau/TypeVar.h b/Analysis/include/Luau/TypeVar.h index 29578dcd9..b8c4b362f 100644 --- a/Analysis/include/Luau/TypeVar.h +++ b/Analysis/include/Luau/TypeVar.h @@ -488,6 +488,9 @@ const TableTypeVar* getTableType(TypeId type); // Returns nullptr if the type has no name. const std::string* getName(TypeId type); +// Returns name of the module where type was defined if type has that information +std::optional getDefinitionModuleName(TypeId type); + // Checks whether a union contains all types of another union. bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub); diff --git a/Analysis/include/Luau/Unifier.h b/Analysis/include/Luau/Unifier.h index f1ffbcc01..474af50cc 100644 --- a/Analysis/include/Luau/Unifier.h +++ b/Analysis/include/Luau/Unifier.h @@ -90,7 +90,9 @@ struct Unifier TypeId deeplyOptional(TypeId ty, std::unordered_map seen = {}); - void cacheResult(TypeId subTy, TypeId superTy); + bool canCacheResult(TypeId subTy, TypeId superTy); + void cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount); + void cacheResult_DEPRECATED(TypeId subTy, TypeId superTy); public: void tryUnify(TypePackId subTy, TypePackId superTy, bool isFunctionCall = false); diff --git a/Analysis/include/Luau/UnifierSharedState.h b/Analysis/include/Luau/UnifierSharedState.h index 88997c41a..9a3ba56d1 100644 --- a/Analysis/include/Luau/UnifierSharedState.h +++ b/Analysis/include/Luau/UnifierSharedState.h @@ -2,6 +2,7 @@ #pragma once #include "Luau/DenseHash.h" +#include "Luau/Error.h" #include "Luau/TypeVar.h" #include "Luau/TypePack.h" @@ -42,6 +43,7 @@ struct UnifierSharedState DenseHashSet seenAny{nullptr}; DenseHashMap skipCacheForType{nullptr}; DenseHashSet, TypeIdPairHash> cachedUnify{{nullptr, nullptr}}; + DenseHashMap, TypeErrorData, TypeIdPairHash> cachedUnifyError{{nullptr, nullptr}}; DenseHashSet tempSeenTy{nullptr}; DenseHashSet tempSeenTp{nullptr}; diff --git a/Analysis/src/Error.cpp b/Analysis/src/Error.cpp index 26d3b76da..210c0191a 100644 --- a/Analysis/src/Error.cpp +++ b/Analysis/src/Error.cpp @@ -8,6 +8,7 @@ #include LUAU_FASTFLAGVARIABLE(BetterDiagnosticCodesInStudio, false); +LUAU_FASTFLAGVARIABLE(LuauTypeMismatchModuleName, false); static std::string wrongNumberOfArgsString(size_t expectedCount, size_t actualCount, const char* argPrefix = nullptr, bool isVariadic = false) { @@ -53,7 +54,32 @@ struct ErrorConverter { std::string operator()(const Luau::TypeMismatch& tm) const { - std::string result = "Type '" + Luau::toString(tm.givenType) + "' could not be converted into '" + Luau::toString(tm.wantedType) + "'"; + std::string givenTypeName = Luau::toString(tm.givenType); + std::string wantedTypeName = Luau::toString(tm.wantedType); + + std::string result; + + if (FFlag::LuauTypeMismatchModuleName) + { + if (givenTypeName == wantedTypeName) + { + if (auto givenDefinitionModule = getDefinitionModuleName(tm.givenType)) + { + if (auto wantedDefinitionModule = getDefinitionModuleName(tm.wantedType)) + { + result = "Type '" + givenTypeName + "' from '" + *givenDefinitionModule + "' could not be converted into '" + wantedTypeName + + "' from '" + *wantedDefinitionModule + "'"; + } + } + } + + if (result.empty()) + result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + } + else + { + result = "Type '" + givenTypeName + "' could not be converted into '" + wantedTypeName + "'"; + } if (tm.error) { @@ -147,7 +173,7 @@ struct ErrorConverter return "Function only returns " + std::to_string(e.expected) + " value" + expectedS + ". " + std::to_string(e.actual) + " are required here"; case CountMismatch::Arg: - return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual); + return "Argument count mismatch. Function " + wrongNumberOfArgsString(e.expected, e.actual, /*argPrefix*/ nullptr, e.isVariadic); } LUAU_ASSERT(!"Unknown context"); diff --git a/Analysis/src/Linter.cpp b/Analysis/src/Linter.cpp index 56c4e3e89..b7480e345 100644 --- a/Analysis/src/Linter.cpp +++ b/Analysis/src/Linter.cpp @@ -14,6 +14,7 @@ LUAU_FASTINTVARIABLE(LuauSuggestionDistance, 4) LUAU_FASTFLAGVARIABLE(LuauLintGlobalNeverReadBeforeWritten, false) +LUAU_FASTFLAGVARIABLE(LuauLintNoRobloxBits, false) namespace Luau { @@ -1135,16 +1136,20 @@ class LintUnknownType : AstVisitor enum TypeKind { - Kind_Invalid, + Kind_Unknown, Kind_Primitive, // primitive type supported by VM - boolean/userdata/etc. No differentiation between types of userdata. - Kind_Vector, // For 'vector' but only used when type is used - Kind_Userdata, // custom userdata type - Vector3/etc. + Kind_Vector, // 'vector' but only used when type is used + Kind_Userdata, // custom userdata type + + // TODO: remove these with LuauLintNoRobloxBits Kind_Class, // custom userdata type that reflects Roblox Instance-derived hierarchy - Part/etc. Kind_Enum, // custom userdata type referring to an enum item of enum classes, e.g. Enum.NormalId.Back/Enum.Axis.X/etc. }; bool containsPropName(TypeId ty, const std::string& propName) { + LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); + if (auto ctv = get(ty)) return lookupClassProp(ctv, propName) != nullptr; @@ -1163,13 +1168,23 @@ class LintUnknownType : AstVisitor if (name == "vector") return Kind_Vector; - if (std::optional maybeTy = context->scope->lookupType(name)) - // Kind_Userdata is probably not 100% precise but is close enough - return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; - else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) - return Kind_Enum; + if (FFlag::LuauLintNoRobloxBits) + { + if (std::optional maybeTy = context->scope->lookupType(name)) + return Kind_Userdata; - return Kind_Invalid; + return Kind_Unknown; + } + else + { + if (std::optional maybeTy = context->scope->lookupType(name)) + // Kind_Userdata is probably not 100% precise but is close enough + return containsPropName(maybeTy->type, "ClassName") ? Kind_Class : Kind_Userdata; + else if (std::optional maybeTy = context->scope->lookupImportedType("Enum", name)) + return Kind_Enum; + + return Kind_Unknown; + } } void validateType(AstExprConstantString* expr, std::initializer_list expected, const char* expectedString) @@ -1177,7 +1192,7 @@ class LintUnknownType : AstVisitor std::string name(expr->value.data, expr->value.size); TypeKind kind = getTypeKind(name); - if (kind == Kind_Invalid) + if (kind == Kind_Unknown) { emitWarning(*context, LintWarning::Code_UnknownType, expr->location, "Unknown type '%s'", name.c_str()); return; @@ -1189,7 +1204,7 @@ class LintUnknownType : AstVisitor return; // as a special case, Instance and EnumItem are both a userdata type (as returned by typeof) and a class type - if (ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) + if (!FFlag::LuauLintNoRobloxBits && ek == Kind_Userdata && (name == "Instance" || name == "EnumItem")) return; } @@ -1198,12 +1213,18 @@ class LintUnknownType : AstVisitor bool acceptsClassName(AstName method) { + LUAU_ASSERT(!FFlag::LuauLintNoRobloxBits); + return method.value[0] == 'F' && (method == "FindFirstChildOfClass" || method == "FindFirstChildWhichIsA" || method == "FindFirstAncestorOfClass" || method == "FindFirstAncestorWhichIsA"); } bool visit(AstExprCall* node) override { + // TODO: Simply remove the override + if (FFlag::LuauLintNoRobloxBits) + return true; + if (AstExprIndexName* index = node->func->as()) { AstExprConstantString* arg0 = node->args.size > 0 ? node->args.data[0]->as() : NULL; diff --git a/Analysis/src/Module.cpp b/Analysis/src/Module.cpp index a330a98d1..0787d3a42 100644 --- a/Analysis/src/Module.cpp +++ b/Analysis/src/Module.cpp @@ -12,10 +12,8 @@ #include LUAU_FASTFLAGVARIABLE(DebugLuauFreezeArena, false) -LUAU_FASTFLAGVARIABLE(DebugLuauTrackOwningArena, false) // Remove with FFlagLuauImmutableTypes LUAU_FASTINTVARIABLE(LuauTypeCloneRecursionLimit, 300) LUAU_FASTFLAGVARIABLE(LuauCloneDeclaredGlobals, false) -LUAU_FASTFLAG(LuauImmutableTypes) namespace Luau { @@ -65,8 +63,7 @@ TypeId TypeArena::addTV(TypeVar&& tv) { TypeId allocated = typeVars.allocate(std::move(tv)); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -75,8 +72,7 @@ TypeId TypeArena::freshType(TypeLevel level) { TypeId allocated = typeVars.allocate(FreeTypeVar{level}); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -85,8 +81,7 @@ TypePackId TypeArena::addTypePack(std::initializer_list types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -95,8 +90,7 @@ TypePackId TypeArena::addTypePack(std::vector types) { TypePackId allocated = typePacks.allocate(TypePack{std::move(types)}); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -105,8 +99,7 @@ TypePackId TypeArena::addTypePack(TypePack tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -115,8 +108,7 @@ TypePackId TypeArena::addTypePack(TypePackVar tp) { TypePackId allocated = typePacks.allocate(std::move(tp)); - if (FFlag::DebugLuauTrackOwningArena || FFlag::LuauImmutableTypes) - asMutable(allocated)->owningArena = this; + asMutable(allocated)->owningArena = this; return allocated; } @@ -439,16 +431,9 @@ TypeId clone(TypeId typeId, TypeArena& dest, SeenTypes& seenTypes, SeenTypePacks TypeCloner cloner{dest, typeId, seenTypes, seenTypePacks, cloneState}; Luau::visit(cloner, typeId->ty); // Mutates the storage that 'res' points into. - if (FFlag::LuauImmutableTypes) - { - // Persistent types are not being cloned and we get the original type back which might be read-only - if (!res->persistent) - asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } - else - { + // Persistent types are not being cloned and we get the original type back which might be read-only + if (!res->persistent) asMutable(res)->documentationSymbol = typeId->documentationSymbol; - } } return res; diff --git a/Analysis/src/ToString.cpp b/Analysis/src/ToString.cpp index 010ca3612..59ee6de20 100644 --- a/Analysis/src/ToString.cpp +++ b/Analysis/src/ToString.cpp @@ -16,6 +16,7 @@ * Fair warning: Setting this will break a lot of Luau unit tests. */ LUAU_FASTFLAGVARIABLE(DebugLuauVerboseTypeNames, false) +LUAU_FASTFLAGVARIABLE(LuauDocFuncParameters, false) namespace Luau { @@ -769,6 +770,7 @@ struct TypePackStringifier else state.emit(", "); + // Do not respect opts.namedFunctionOverrideArgNames here if (elemIndex < elemNames.size() && elemNames[elemIndex]) { state.emit(elemNames[elemIndex]->name); @@ -1090,13 +1092,13 @@ std::string toString(const TypePackVar& tp, const ToStringOptions& opts) return toString(const_cast(&tp), std::move(opts)); } -std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeVar& ftv, ToStringOptions opts) +std::string toStringNamedFunction(const std::string& funcName, const FunctionTypeVar& ftv, const ToStringOptions& opts) { ToStringResult result; StringifierState state(opts, result, opts.nameMap); TypeVarStringifier tvs{state}; - state.emit(prefix); + state.emit(funcName); if (!opts.hideNamedFunctionTypeParameters) tvs.stringify(ftv.generics, ftv.genericPacks); @@ -1104,28 +1106,59 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV state.emit("("); auto argPackIter = begin(ftv.argTypes); - auto argNameIter = ftv.argNames.begin(); bool first = true; - while (argPackIter != end(ftv.argTypes)) + if (FFlag::LuauDocFuncParameters) { - if (!first) - state.emit(", "); - first = false; - - // We don't currently respect opts.functionTypeArguments. I don't think this function should. - if (argNameIter != ftv.argNames.end()) + size_t idx = 0; + while (argPackIter != end(ftv.argTypes)) { - state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); - ++argNameIter; + if (!first) + state.emit(", "); + first = false; + + // We don't respect opts.functionTypeArguments + if (idx < opts.namedFunctionOverrideArgNames.size()) + { + state.emit(opts.namedFunctionOverrideArgNames[idx] + ": "); + } + else if (idx < ftv.argNames.size() && ftv.argNames[idx]) + { + state.emit(ftv.argNames[idx]->name + ": "); + } + else + { + state.emit("_: "); + } + tvs.stringify(*argPackIter); + + ++argPackIter; + ++idx; } - else + } + else + { + auto argNameIter = ftv.argNames.begin(); + while (argPackIter != end(ftv.argTypes)) { - state.emit("_: "); - } + if (!first) + state.emit(", "); + first = false; - tvs.stringify(*argPackIter); - ++argPackIter; + // We don't currently respect opts.functionTypeArguments. I don't think this function should. + if (argNameIter != ftv.argNames.end()) + { + state.emit((*argNameIter ? (*argNameIter)->name : "_") + ": "); + ++argNameIter; + } + else + { + state.emit("_: "); + } + + tvs.stringify(*argPackIter); + ++argPackIter; + } } if (argPackIter.tail()) @@ -1134,7 +1167,6 @@ std::string toStringNamedFunction(const std::string& prefix, const FunctionTypeV state.emit(", "); state.emit("...: "); - if (auto vtp = get(*argPackIter.tail())) tvs.stringify(vtp->ty); else diff --git a/Analysis/src/TypeInfer.cpp b/Analysis/src/TypeInfer.cpp index 41e8ce55f..9965d5aaf 100644 --- a/Analysis/src/TypeInfer.cpp +++ b/Analysis/src/TypeInfer.cpp @@ -27,10 +27,8 @@ LUAU_FASTFLAGVARIABLE(LuauWeakEqConstraint, false) // Eventually removed as fals LUAU_FASTFLAGVARIABLE(DebugLuauFreezeDuringUnification, false) LUAU_FASTFLAGVARIABLE(LuauRecursiveTypeParameterRestriction, false) LUAU_FASTFLAGVARIABLE(LuauGenericFunctionsDontCacheTypeParams, false) -LUAU_FASTFLAGVARIABLE(LuauImmutableTypes, false) LUAU_FASTFLAGVARIABLE(LuauSealExports, false) LUAU_FASTFLAGVARIABLE(LuauSelfCallAutocompleteFix, false) -LUAU_FASTFLAGVARIABLE(LuauSingletonTypes, false) LUAU_FASTFLAGVARIABLE(LuauDiscriminableUnions2, false) LUAU_FASTFLAGVARIABLE(LuauExpectedTypesOfProperties, false) LUAU_FASTFLAGVARIABLE(LuauErrorRecoveryType, false) @@ -38,6 +36,7 @@ LUAU_FASTFLAGVARIABLE(LuauOnlyMutateInstantiatedTables, false) LUAU_FASTFLAGVARIABLE(LuauPropertiesGetExpectedType, false) LUAU_FASTFLAGVARIABLE(LuauStatFunctionSimplify2, false) LUAU_FASTFLAGVARIABLE(LuauUnsealedTableLiteral, false) +LUAU_FASTFLAG(LuauTypeMismatchModuleName) LUAU_FASTFLAGVARIABLE(LuauTwoPassAliasDefinitionFix, false) LUAU_FASTFLAGVARIABLE(LuauAssertStripsFalsyTypes, false) LUAU_FASTFLAGVARIABLE(LuauReturnAnyInsteadOfICE, false) // Eventually removed as false. @@ -47,6 +46,8 @@ LUAU_FASTFLAGVARIABLE(LuauDoNotAccidentallyDependOnPointerOrdering, false) LUAU_FASTFLAGVARIABLE(LuauFixArgumentCountMismatchAmountWithGenericTypes, false) LUAU_FASTFLAGVARIABLE(LuauFixIncorrectLineNumberDuplicateType, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) +LUAU_FASTFLAGVARIABLE(LuauDecoupleOperatorInferenceFromUnifiedTypeInference, false) +LUAU_FASTFLAGVARIABLE(LuauArgCountMismatchSaysAtLeastWhenVariadic, false) namespace Luau { @@ -291,6 +292,7 @@ ModulePtr TypeChecker::check(const SourceModule& module, Mode mode, std::optiona // Clear unifier cache since it's keyed off internal types that get deallocated // This avoids fake cross-module cache hits and keeps cache size at bay when typechecking large module graphs. unifierState.cachedUnify.clear(); + unifierState.cachedUnifyError.clear(); unifierState.skipCacheForType.clear(); if (FFlag::LuauTwoPassAliasDefinitionFix) @@ -1303,7 +1305,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias { // If the table is already named and we want to rename the type function, we have to bind new alias to a copy // Additionally, we can't modify types that come from other modules - if (ttv->name || (FFlag::LuauImmutableTypes && follow(ty)->owningArena != ¤tModule->internalTypes)) + if (ttv->name || follow(ty)->owningArena != ¤tModule->internalTypes) { bool sameTys = std::equal(ttv->instantiatedTypeParams.begin(), ttv->instantiatedTypeParams.end(), binding->typeParams.begin(), binding->typeParams.end(), [](auto&& itp, auto&& tp) { @@ -1315,7 +1317,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias }); // Copy can be skipped if this is an identical alias - if ((FFlag::LuauImmutableTypes && !ttv->name) || ttv->name != name || !sameTys || !sameTps) + if (!ttv->name || ttv->name != name || !sameTys || !sameTps) { // This is a shallow clone, original recursive links to self are not updated TableTypeVar clone = TableTypeVar{ttv->props, ttv->indexer, ttv->level, ttv->state}; @@ -1349,7 +1351,7 @@ void TypeChecker::check(const ScopePtr& scope, const AstStatTypeAlias& typealias else if (auto mtv = getMutable(follow(ty))) { // We can't modify types that come from other modules - if (!FFlag::LuauImmutableTypes || follow(ty)->owningArena == ¤tModule->internalTypes) + if (follow(ty)->owningArena == ¤tModule->internalTypes) mtv->syntheticName = name; } @@ -1512,14 +1514,14 @@ ExprResult TypeChecker::checkExpr(const ScopePtr& scope, const AstExpr& result = {nilType}; else if (const AstExprConstantBool* bexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) + if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) result = {singletonType(bexpr->value)}; else result = {booleanType}; } else if (const AstExprConstantString* sexpr = expr.as()) { - if (FFlag::LuauSingletonTypes && (forceSingleton || (expectedType && maybeSingleton(*expectedType)))) + if (forceSingleton || (expectedType && maybeSingleton(*expectedType))) result = {singletonType(std::string(sexpr->value.data, sexpr->value.size))}; else result = {stringType}; @@ -2490,12 +2492,24 @@ TypeId TypeChecker::checkBinaryOperation( lhsType = follow(lhsType); rhsType = follow(rhsType); - if (!isNonstrictMode() && get(lhsType)) + if (FFlag::LuauDecoupleOperatorInferenceFromUnifiedTypeInference) { - auto name = getIdentifierOfBaseVar(expr.left); - reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); - if (!FFlag::LuauErrorRecoveryType) - return errorRecoveryType(scope); + if (!isNonstrictMode() && get(lhsType)) + { + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + // We will fall-through to the `return anyType` check below. + } + } + else + { + if (!isNonstrictMode() && get(lhsType)) + { + auto name = getIdentifierOfBaseVar(expr.left); + reportError(expr.location, CannotInferBinaryOperation{expr.op, name, CannotInferBinaryOperation::Operation}); + if (!FFlag::LuauErrorRecoveryType) + return errorRecoveryType(scope); + } } // If we know nothing at all about the lhs type, we can usually say nothing about the result. @@ -3452,7 +3466,8 @@ void TypeChecker::checkArgumentList( { if (FFlag::LuauFixArgumentCountMismatchAmountWithGenericTypes) minParams = getMinParameterCount(&state.log, paramPack); - state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex}}); + bool isVariadic = FFlag::LuauArgCountMismatchSaysAtLeastWhenVariadic && !finite(paramPack, &state.log); + state.reportError(TypeError{state.location, CountMismatch{minParams, paramIndex, CountMismatch::Context::Arg, isVariadic}}); return; } ++paramIter; @@ -4163,13 +4178,7 @@ TypeId TypeChecker::checkRequire(const ScopePtr& scope, const ModuleInfo& module return errorRecoveryType(scope); } - if (FFlag::LuauImmutableTypes) - return *moduleType; - - SeenTypes seenTypes; - SeenTypePacks seenTypePacks; - CloneState cloneState; - return clone(*moduleType, currentModule->internalTypes, seenTypes, seenTypePacks, cloneState); + return *moduleType; } void TypeChecker::tablify(TypeId type) @@ -4941,10 +4950,19 @@ TypeId TypeChecker::resolveType(const ScopePtr& scope, const AstType& annotation if (const auto& indexer = table->indexer) tableIndexer = TableIndexer(resolveType(scope, *indexer->indexType), resolveType(scope, *indexer->resultType)); - return addType(TableTypeVar{ - props, tableIndexer, scope->level, - TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe - }); + if (FFlag::LuauTypeMismatchModuleName) + { + TableTypeVar ttv{props, tableIndexer, scope->level, TableState::Sealed}; + ttv.definitionModuleName = currentModuleName; + return addType(std::move(ttv)); + } + else + { + return addType(TableTypeVar{ + props, tableIndexer, scope->level, + TableState::Sealed // FIXME: probably want a way to annotate other kinds of tables maybe + }); + } } else if (const auto& func = annotation.as()) { @@ -5206,6 +5224,9 @@ TypeId TypeChecker::instantiateTypeFun(const ScopePtr& scope, const TypeFun& tf, { ttv->instantiatedTypeParams = typeParams; ttv->instantiatedTypePackParams = typePackParams; + + if (FFlag::LuauTypeMismatchModuleName) + ttv->definitionModuleName = currentModuleName; } return instantiated; diff --git a/Analysis/src/TypePack.cpp b/Analysis/src/TypePack.cpp index 91123f468..5bb052349 100644 --- a/Analysis/src/TypePack.cpp +++ b/Analysis/src/TypePack.cpp @@ -222,20 +222,21 @@ TypePackId follow(TypePackId tp, std::function mapper) } } -size_t size(TypePackId tp) +size_t size(TypePackId tp, TxnLog* log) { - if (auto pack = get(follow(tp))) - return size(*pack); + tp = log ? log->follow(tp) : follow(tp); + if (auto pack = get(tp)) + return size(*pack, log); else return 0; } -bool finite(TypePackId tp) +bool finite(TypePackId tp, TxnLog* log) { - tp = follow(tp); + tp = log ? log->follow(tp) : follow(tp); if (auto pack = get(tp)) - return pack->tail ? finite(*pack->tail) : true; + return pack->tail ? finite(*pack->tail, log) : true; if (get(tp)) return false; @@ -243,14 +244,14 @@ bool finite(TypePackId tp) return true; } -size_t size(const TypePack& tp) +size_t size(const TypePack& tp, TxnLog* log) { size_t result = tp.head.size(); if (tp.tail) { - const TypePack* tail = get(follow(*tp.tail)); + const TypePack* tail = get(log ? log->follow(*tp.tail) : follow(*tp.tail)); if (tail) - result += size(*tail); + result += size(*tail, log); } return result; } diff --git a/Analysis/src/TypeVar.cpp b/Analysis/src/TypeVar.cpp index 895495352..36545ad90 100644 --- a/Analysis/src/TypeVar.cpp +++ b/Analysis/src/TypeVar.cpp @@ -290,6 +290,24 @@ const std::string* getName(TypeId type) return nullptr; } +std::optional getDefinitionModuleName(TypeId type) +{ + type = follow(type); + + if (auto ttv = get(type)) + { + if (!ttv->definitionModuleName.empty()) + return ttv->definitionModuleName; + } + else if (auto ftv = get(type)) + { + if (ftv->definition) + return ftv->definition->definitionModuleName; + } + + return std::nullopt; +} + bool isSubset(const UnionTypeVar& super, const UnionTypeVar& sub) { std::unordered_set superTypes; diff --git a/Analysis/src/Unifier.cpp b/Analysis/src/Unifier.cpp index 60a9c9a5d..398dc9e25 100644 --- a/Analysis/src/Unifier.cpp +++ b/Analysis/src/Unifier.cpp @@ -14,10 +14,9 @@ LUAU_FASTINT(LuauTypeInferRecursionLimit); LUAU_FASTINT(LuauTypeInferTypePackLoopLimit); -LUAU_FASTFLAG(LuauImmutableTypes) LUAU_FASTINTVARIABLE(LuauTypeInferIterationLimit, 2000); +LUAU_FASTFLAGVARIABLE(LuauExtendedIndexerError, false); LUAU_FASTFLAGVARIABLE(LuauTableSubtypingVariance2, false); -LUAU_FASTFLAG(LuauSingletonTypes) LUAU_FASTFLAG(LuauErrorRecoveryType); LUAU_FASTFLAGVARIABLE(LuauSubtypingAddOptPropsToUnsealedTables, false) LUAU_FASTFLAGVARIABLE(LuauWidenIfSupertypeIsFree2, false) @@ -26,6 +25,7 @@ LUAU_FASTFLAGVARIABLE(LuauTxnLogSeesTypePacks2, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogCheckForInvalidation, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogRefreshFunctionPointers, false) LUAU_FASTFLAGVARIABLE(LuauTxnLogDontRetryForIndexers, false) +LUAU_FASTFLAGVARIABLE(LuauUnifierCacheErrors, false) LUAU_FASTFLAG(LuauAnyInIsOptionalIsOptional) namespace Luau @@ -63,7 +63,7 @@ struct PromoteTypeLevels bool operator()(TID ty, const T&) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; return true; @@ -83,7 +83,7 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const FunctionTypeVar&) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; promote(ty, log.getMutable(ty)); @@ -93,7 +93,7 @@ struct PromoteTypeLevels bool operator()(TypeId ty, const TableTypeVar& ttv) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; if (ttv.state != TableState::Free && ttv.state != TableState::Generic) @@ -118,7 +118,7 @@ struct PromoteTypeLevels static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypeId ty) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return; PromoteTypeLevels ptl{log, typeArena, minLevel}; @@ -130,7 +130,7 @@ static void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel void promoteTypeLevels(TxnLog& log, const TypeArena* typeArena, TypeLevel minLevel, TypePackId tp) { // Type levels of types from other modules are already global, so we don't need to promote anything inside - if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + if (tp->owningArena != typeArena) return; PromoteTypeLevels ptl{log, typeArena, minLevel}; @@ -170,7 +170,7 @@ struct SkipCacheForType bool operator()(TypeId ty, const TableTypeVar&) { // Types from other modules don't contain mutable elements and are ok to cache - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; TableTypeVar& ttv = *getMutable(ty); @@ -194,7 +194,7 @@ struct SkipCacheForType bool operator()(TypeId ty, const T& t) { // Types from other modules don't contain mutable elements and are ok to cache - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) return false; const bool* prev = skipCacheForType.find(ty); @@ -212,7 +212,7 @@ struct SkipCacheForType bool operator()(TypePackId tp, const T&) { // Types from other modules don't contain mutable elements and are ok to cache - if (FFlag::LuauImmutableTypes && tp->owningArena != typeArena) + if (tp->owningArena != typeArena) return false; return true; @@ -445,12 +445,33 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool if (get(subTy) || get(subTy)) return tryUnifyWithAny(superTy, subTy); - bool cacheEnabled = !isFunctionCall && !isIntersection; + bool cacheEnabled; auto& cache = sharedState.cachedUnify; // What if the types are immutable and we proved their relation before - if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) - return; + if (FFlag::LuauUnifierCacheErrors) + { + cacheEnabled = !isFunctionCall && !isIntersection && variance == Invariant; + + if (cacheEnabled) + { + if (cache.contains({subTy, superTy})) + return; + + if (auto error = sharedState.cachedUnifyError.find({subTy, superTy})) + { + reportError(TypeError{location, *error}); + return; + } + } + } + else + { + cacheEnabled = !isFunctionCall && !isIntersection; + + if (cacheEnabled && cache.contains({superTy, subTy}) && (variance == Covariant || cache.contains({subTy, superTy}))) + return; + } // If we have seen this pair of types before, we are currently recursing into cyclic types. // Here, we assume that the types unify. If they do not, we will find out as we roll back @@ -461,6 +482,8 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool log.pushSeen(superTy, subTy); + size_t errorCount = errors.size(); + if (const UnionTypeVar* uv = log.getMutable(subTy)) { tryUnifyUnionWithType(subTy, uv, superTy); @@ -480,8 +503,7 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else if (log.getMutable(superTy) && log.getMutable(subTy)) tryUnifyPrimitives(subTy, superTy); - else if (FFlag::LuauSingletonTypes && (log.getMutable(superTy) || log.getMutable(superTy)) && - log.getMutable(subTy)) + else if ((log.getMutable(superTy) || log.getMutable(superTy)) && log.getMutable(subTy)) tryUnifySingletons(subTy, superTy); else if (log.getMutable(superTy) && log.getMutable(subTy)) @@ -491,8 +513,11 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool { tryUnifyTables(subTy, superTy, isIntersection); - if (cacheEnabled && errors.empty()) - cacheResult(subTy, superTy); + if (!FFlag::LuauUnifierCacheErrors) + { + if (cacheEnabled && errors.empty()) + cacheResult_DEPRECATED(subTy, superTy); + } } // tryUnifyWithMetatable assumes its first argument is a MetatableTypeVar. The check is otherwise symmetrical. @@ -512,6 +537,9 @@ void Unifier::tryUnify_(TypeId subTy, TypeId superTy, bool isFunctionCall, bool else reportError(TypeError{location, TypeMismatch{superTy, subTy}}); + if (FFlag::LuauUnifierCacheErrors && cacheEnabled) + cacheResult(subTy, superTy, errorCount); + log.popSeen(superTy, subTy); } @@ -646,10 +674,21 @@ void Unifier::tryUnifyTypeWithUnion(TypeId subTy, TypeId superTy, const UnionTyp { TypeId type = uv->options[i]; - if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + if (FFlag::LuauUnifierCacheErrors) { - startIndex = i; - break; + if (cache.contains({subTy, type})) + { + startIndex = i; + break; + } + } + else + { + if (cache.contains({type, subTy}) && (variance == Covariant || cache.contains({subTy, type}))) + { + startIndex = i; + break; + } } } } @@ -737,10 +776,21 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV { TypeId type = uv->parts[i]; - if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + if (FFlag::LuauUnifierCacheErrors) { - startIndex = i; - break; + if (cache.contains({type, superTy})) + { + startIndex = i; + break; + } + } + else + { + if (cache.contains({superTy, type}) && (variance == Covariant || cache.contains({type, superTy}))) + { + startIndex = i; + break; + } } } } @@ -771,17 +821,17 @@ void Unifier::tryUnifyIntersectionWithType(TypeId subTy, const IntersectionTypeV } } -void Unifier::cacheResult(TypeId subTy, TypeId superTy) +bool Unifier::canCacheResult(TypeId subTy, TypeId superTy) { bool* superTyInfo = sharedState.skipCacheForType.find(superTy); if (superTyInfo && *superTyInfo) - return; + return false; bool* subTyInfo = sharedState.skipCacheForType.find(subTy); if (subTyInfo && *subTyInfo) - return; + return false; auto skipCacheFor = [this](TypeId ty) { SkipCacheForType visitor{sharedState.skipCacheForType, types}; @@ -793,9 +843,33 @@ void Unifier::cacheResult(TypeId subTy, TypeId superTy) }; if (!superTyInfo && skipCacheFor(superTy)) - return; + return false; if (!subTyInfo && skipCacheFor(subTy)) + return false; + + return true; +} + +void Unifier::cacheResult(TypeId subTy, TypeId superTy, size_t prevErrorCount) +{ + if (errors.size() == prevErrorCount) + { + if (canCacheResult(subTy, superTy)) + sharedState.cachedUnify.insert({subTy, superTy}); + } + else if (errors.size() == prevErrorCount + 1) + { + if (canCacheResult(subTy, superTy)) + sharedState.cachedUnifyError[{subTy, superTy}] = errors.back().data; + } +} + +void Unifier::cacheResult_DEPRECATED(TypeId subTy, TypeId superTy) +{ + LUAU_ASSERT(!FFlag::LuauUnifierCacheErrors); + + if (!canCacheResult(subTy, superTy)) return; sharedState.cachedUnify.insert({superTy, subTy}); @@ -1283,24 +1357,6 @@ void Unifier::tryUnifyFunctions(TypeId subTy, TypeId superTy, bool isFunctionCal subFunction = log.getMutable(subTy); } - if (!FFlag::LuauImmutableTypes) - { - if (superFunction->definition && !subFunction->definition && !subTy->persistent) - { - PendingType* newSubTy = log.queue(subTy); - FunctionTypeVar* newSubFtv = getMutable(newSubTy); - LUAU_ASSERT(newSubFtv); - newSubFtv->definition = superFunction->definition; - } - else if (!superFunction->definition && subFunction->definition && !superTy->persistent) - { - PendingType* newSuperTy = log.queue(superTy); - FunctionTypeVar* newSuperFtv = getMutable(newSuperTy); - LUAU_ASSERT(newSuperFtv); - newSuperFtv->definition = subFunction->definition; - } - } - ctx = context; if (FFlag::LuauTxnLogSeesTypePacks2) @@ -1563,8 +1619,25 @@ void Unifier::tryUnifyTables(TypeId subTy, TypeId superTy, bool isIntersection) variance = Invariant; Unifier innerState = makeChildUnifier(); - innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); - checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + + if (FFlag::LuauExtendedIndexerError) + { + innerState.tryUnify_(subTable->indexer->indexType, superTable->indexer->indexType); + + bool reported = !innerState.errors.empty(); + + checkChildUnifierTypeMismatch(innerState.errors, "[indexer key]", superTy, subTy); + + innerState.tryUnify_(subTable->indexer->indexResultType, superTable->indexer->indexResultType); + + if (!reported) + checkChildUnifierTypeMismatch(innerState.errors, "[indexer value]", superTy, subTy); + } + else + { + innerState.tryUnifyIndexer(*subTable->indexer, *superTable->indexer); + checkChildUnifierTypeMismatch(innerState.errors, superTy, subTy); + } if (innerState.errors.empty()) log.concat(std::move(innerState.log)); @@ -1771,6 +1844,7 @@ void Unifier::DEPRECATED_tryUnifyTables(TypeId subTy, TypeId superTy, bool isInt void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) { + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); TableTypeVar* freeTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); @@ -1840,6 +1914,7 @@ void Unifier::tryUnifyFreeTable(TypeId subTy, TypeId superTy) void Unifier::tryUnifySealedTables(TypeId subTy, TypeId superTy, bool isIntersection) { + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2); TableTypeVar* superTable = log.getMutable(superTy); TableTypeVar* subTable = log.getMutable(subTy); @@ -2120,6 +2195,8 @@ void Unifier::tryUnifyWithClass(TypeId subTy, TypeId superTy, bool reversed) void Unifier::tryUnifyIndexer(const TableIndexer& subIndexer, const TableIndexer& superIndexer) { + LUAU_ASSERT(!FFlag::LuauTableSubtypingVariance2 || !FFlag::LuauExtendedIndexerError); + tryUnify_(subIndexer.indexType, superIndexer.indexType); tryUnify_(subIndexer.indexResultType, superIndexer.indexResultType); } @@ -2211,7 +2288,7 @@ static void tryUnifyWithAny(std::vector& queue, Unifier& state, DenseHas queue.pop_back(); // Types from other modules don't have free types - if (FFlag::LuauImmutableTypes && ty->owningArena != typeArena) + if (ty->owningArena != typeArena) continue; if (seen.find(ty)) diff --git a/Ast/src/Parser.cpp b/Ast/src/Parser.cpp index 941a3ea4f..f6dfd9046 100644 --- a/Ast/src/Parser.cpp +++ b/Ast/src/Parser.cpp @@ -10,8 +10,6 @@ // See docs/SyntaxChanges.md for an explanation. LUAU_FASTINTVARIABLE(LuauRecursionLimit, 1000) LUAU_FASTINTVARIABLE(LuauParseErrorLimit, 100) -LUAU_FASTFLAGVARIABLE(LuauParseSingletonTypes, false) -LUAU_FASTFLAGVARIABLE(LuauTableFieldFunctionDebugname, false) namespace Luau { @@ -1233,8 +1231,7 @@ AstType* Parser::parseTableTypeAnnotation() while (lexer.current().type != '}') { - if (FFlag::LuauParseSingletonTypes && lexer.current().type == '[' && - (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) + if (lexer.current().type == '[' && (lexer.lookahead().type == Lexeme::RawString || lexer.lookahead().type == Lexeme::QuotedString)) { const Lexeme begin = lexer.current(); nextLexeme(); // [ @@ -1500,17 +1497,17 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) nextLexeme(); return {allocator.alloc(begin, std::nullopt, nameNil), {}}; } - else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedTrue) + else if (lexer.current().type == Lexeme::ReservedTrue) { nextLexeme(); return {allocator.alloc(begin, true)}; } - else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::ReservedFalse) + else if (lexer.current().type == Lexeme::ReservedFalse) { nextLexeme(); return {allocator.alloc(begin, false)}; } - else if (FFlag::LuauParseSingletonTypes && (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString)) + else if (lexer.current().type == Lexeme::RawString || lexer.current().type == Lexeme::QuotedString) { if (std::optional> value = parseCharArray()) { @@ -1520,7 +1517,7 @@ AstTypeOrPack Parser::parseSimpleTypeAnnotation(bool allowPack) else return {reportTypeAnnotationError(begin, {}, /*isMissing*/ false, "String literal contains malformed escape sequence")}; } - else if (FFlag::LuauParseSingletonTypes && lexer.current().type == Lexeme::BrokenString) + else if (lexer.current().type == Lexeme::BrokenString) { Location location = lexer.current().location; nextLexeme(); @@ -2189,11 +2186,8 @@ AstExpr* Parser::parseTableConstructor() AstExpr* key = allocator.alloc(name.location, nameString); AstExpr* value = parseExpr(); - if (FFlag::LuauTableFieldFunctionDebugname) - { - if (AstExprFunction* func = value->as()) - func->debugname = name.name; - } + if (AstExprFunction* func = value->as()) + func->debugname = name.name; items.push_back({AstExprTable::Item::Record, key, value}); } diff --git a/VM/src/lapi.cpp b/VM/src/lapi.cpp index 3c0873147..46b109347 100644 --- a/VM/src/lapi.cpp +++ b/VM/src/lapi.cpp @@ -14,8 +14,6 @@ #include -LUAU_FASTFLAG(LuauGcAdditionalStats) - const char* lua_ident = "$Lua: Lua 5.1.4 Copyright (C) 1994-2008 Lua.org, PUC-Rio $\n" "$Authors: R. Ierusalimschy, L. H. de Figueiredo & W. Celes $\n" "$URL: www.lua.org $\n"; @@ -1060,8 +1058,11 @@ int lua_gc(lua_State* L, int what, int data) g->GCthreshold = 0; bool waspaused = g->gcstate == GCSpause; - double startmarktime = g->gcstats.currcycle.marktime; - double startsweeptime = g->gcstats.currcycle.sweeptime; + +#ifdef LUAI_GCMETRICS + double startmarktime = g->gcmetrics.currcycle.marktime; + double startsweeptime = g->gcmetrics.currcycle.sweeptime; +#endif // track how much work the loop will actually perform size_t actualwork = 0; @@ -1079,30 +1080,29 @@ int lua_gc(lua_State* L, int what, int data) } } - if (FFlag::LuauGcAdditionalStats) - { - // record explicit step statistics - GCCycleStats* cyclestats = g->gcstate == GCSpause ? &g->gcstats.lastcycle : &g->gcstats.currcycle; +#ifdef LUAI_GCMETRICS + // record explicit step statistics + GCCycleMetrics* cyclemetrics = g->gcstate == GCSpause ? &g->gcmetrics.lastcycle : &g->gcmetrics.currcycle; - double totalmarktime = cyclestats->marktime - startmarktime; - double totalsweeptime = cyclestats->sweeptime - startsweeptime; + double totalmarktime = cyclemetrics->marktime - startmarktime; + double totalsweeptime = cyclemetrics->sweeptime - startsweeptime; - if (totalmarktime > 0.0) - { - cyclestats->markexplicitsteps++; + if (totalmarktime > 0.0) + { + cyclemetrics->markexplicitsteps++; - if (totalmarktime > cyclestats->markmaxexplicittime) - cyclestats->markmaxexplicittime = totalmarktime; - } + if (totalmarktime > cyclemetrics->markmaxexplicittime) + cyclemetrics->markmaxexplicittime = totalmarktime; + } - if (totalsweeptime > 0.0) - { - cyclestats->sweepexplicitsteps++; + if (totalsweeptime > 0.0) + { + cyclemetrics->sweepexplicitsteps++; - if (totalsweeptime > cyclestats->sweepmaxexplicittime) - cyclestats->sweepmaxexplicittime = totalsweeptime; - } + if (totalsweeptime > cyclemetrics->sweepmaxexplicittime) + cyclemetrics->sweepmaxexplicittime = totalsweeptime; } +#endif // if cycle hasn't finished, advance threshold forward for the amount of extra work performed if (g->gcstate != GCSpause) diff --git a/VM/src/ldo.cpp b/VM/src/ldo.cpp index b5ae496b5..c133a59e1 100644 --- a/VM/src/ldo.cpp +++ b/VM/src/ldo.cpp @@ -17,8 +17,6 @@ #include -LUAU_FASTFLAG(LuauReduceStackReallocs) - /* ** {====================================================== ** Error-recovery functions @@ -33,6 +31,15 @@ struct lua_jmpbuf jmp_buf buf; }; +/* use POSIX versions of setjmp/longjmp if possible: they don't save/restore signal mask and are therefore faster */ +#if defined(__linux__) || defined(__APPLE__) +#define LUAU_SETJMP(buf) _setjmp(buf) +#define LUAU_LONGJMP(buf, code) _longjmp(buf, code) +#else +#define LUAU_SETJMP(buf) setjmp(buf) +#define LUAU_LONGJMP(buf, code) longjmp(buf, code) +#endif + int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud) { lua_jmpbuf jb; @@ -40,7 +47,7 @@ int luaD_rawrunprotected(lua_State* L, Pfunc f, void* ud) jb.status = 0; L->global->errorjmp = &jb; - if (setjmp(jb.buf) == 0) + if (LUAU_SETJMP(jb.buf) == 0) f(L, ud); L->global->errorjmp = jb.prev; @@ -52,7 +59,7 @@ l_noret luaD_throw(lua_State* L, int errcode) if (lua_jmpbuf* jb = L->global->errorjmp) { jb->status = errcode; - longjmp(jb->buf, 1); + LUAU_LONGJMP(jb->buf, 1); } if (L->global->cb.panic) @@ -165,8 +172,8 @@ static void correctstack(lua_State* L, TValue* oldstack) void luaD_reallocstack(lua_State* L, int newsize) { TValue* oldstack = L->stack; - int realsize = newsize + (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK); - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); + int realsize = newsize + EXTRA_STACK; + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); luaM_reallocarray(L, L->stack, L->stacksize, realsize, TValue, L->memcat); TValue* newstack = L->stack; for (int i = L->stacksize; i < realsize; i++) @@ -514,7 +521,7 @@ static void callerrfunc(lua_State* L, void* ud) static void restore_stack_limit(lua_State* L) { - LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); + LUAU_ASSERT(L->stack_last - L->stack == L->stacksize - EXTRA_STACK); if (L->size_ci > LUAI_MAXCALLS) { /* there was an overflow? */ int inuse = cast_int(L->ci - L->base_ci); diff --git a/VM/src/ldo.h b/VM/src/ldo.h index 1c1480d68..6e16e6f16 100644 --- a/VM/src/ldo.h +++ b/VM/src/ldo.h @@ -11,7 +11,7 @@ if ((char*)L->stack_last - (char*)L->top <= (n) * (int)sizeof(TValue)) \ luaD_growstack(L, n); \ else \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); #define incr_top(L) \ { \ diff --git a/VM/src/lgc.cpp b/VM/src/lgc.cpp index a656854ed..8fc930d54 100644 --- a/VM/src/lgc.cpp +++ b/VM/src/lgc.cpp @@ -11,8 +11,6 @@ #include "lmem.h" #include "ludata.h" -LUAU_FASTFLAGVARIABLE(LuauGcAdditionalStats, false) - #include #define GC_SWEEPMAX 40 @@ -48,7 +46,8 @@ LUAU_FASTFLAGVARIABLE(LuauGcAdditionalStats, false) reallymarkobject(g, obj2gco(t)); \ } -static void recordGcStateTime(global_State* g, int startgcstate, double seconds, bool assist) +#ifdef LUAI_GCMETRICS +static void recordGcStateStep(global_State* g, int startgcstate, double seconds, bool assist, size_t work) { switch (startgcstate) { @@ -56,57 +55,75 @@ static void recordGcStateTime(global_State* g, int startgcstate, double seconds, // record root mark time if we have switched to next state if (g->gcstate == GCSpropagate) { - g->gcstats.currcycle.marktime += seconds; + g->gcmetrics.currcycle.marktime += seconds; - if (FFlag::LuauGcAdditionalStats && assist) - g->gcstats.currcycle.markassisttime += seconds; + if (assist) + g->gcmetrics.currcycle.markassisttime += seconds; } break; case GCSpropagate: case GCSpropagateagain: - g->gcstats.currcycle.marktime += seconds; + g->gcmetrics.currcycle.marktime += seconds; + g->gcmetrics.currcycle.markrequests += g->gcstepsize; - if (FFlag::LuauGcAdditionalStats && assist) - g->gcstats.currcycle.markassisttime += seconds; + if (assist) + g->gcmetrics.currcycle.markassisttime += seconds; break; case GCSatomic: - g->gcstats.currcycle.atomictime += seconds; + g->gcmetrics.currcycle.atomictime += seconds; break; case GCSsweep: - g->gcstats.currcycle.sweeptime += seconds; + g->gcmetrics.currcycle.sweeptime += seconds; + g->gcmetrics.currcycle.sweeprequests += g->gcstepsize; - if (FFlag::LuauGcAdditionalStats && assist) - g->gcstats.currcycle.sweepassisttime += seconds; + if (assist) + g->gcmetrics.currcycle.sweepassisttime += seconds; break; default: LUAU_ASSERT(!"Unexpected GC state"); } if (assist) - g->gcstats.stepassisttimeacc += seconds; + { + g->gcmetrics.stepassisttimeacc += seconds; + g->gcmetrics.currcycle.assistwork += work; + g->gcmetrics.currcycle.assistrequests += g->gcstepsize; + } else - g->gcstats.stepexplicittimeacc += seconds; + { + g->gcmetrics.stepexplicittimeacc += seconds; + g->gcmetrics.currcycle.explicitwork += work; + g->gcmetrics.currcycle.explicitrequests += g->gcstepsize; + } } -static void startGcCycleStats(global_State* g) +static double recordGcDeltaTime(double& timer) { - g->gcstats.currcycle.starttimestamp = lua_clock(); - g->gcstats.currcycle.pausetime = g->gcstats.currcycle.starttimestamp - g->gcstats.lastcycle.endtimestamp; + double now = lua_clock(); + double delta = now - timer; + timer = now; + return delta; } -static void finishGcCycleStats(global_State* g) +static void startGcCycleMetrics(global_State* g) { - g->gcstats.currcycle.endtimestamp = lua_clock(); - g->gcstats.currcycle.endtotalsizebytes = g->totalbytes; + g->gcmetrics.currcycle.starttimestamp = lua_clock(); + g->gcmetrics.currcycle.pausetime = g->gcmetrics.currcycle.starttimestamp - g->gcmetrics.lastcycle.endtimestamp; +} + +static void finishGcCycleMetrics(global_State* g) +{ + g->gcmetrics.currcycle.endtimestamp = lua_clock(); + g->gcmetrics.currcycle.endtotalsizebytes = g->totalbytes; - g->gcstats.completedcycles++; - g->gcstats.lastcycle = g->gcstats.currcycle; - g->gcstats.currcycle = GCCycleStats(); + g->gcmetrics.completedcycles++; + g->gcmetrics.lastcycle = g->gcmetrics.currcycle; + g->gcmetrics.currcycle = GCCycleMetrics(); - g->gcstats.cyclestatsacc.marktime += g->gcstats.lastcycle.marktime; - g->gcstats.cyclestatsacc.atomictime += g->gcstats.lastcycle.atomictime; - g->gcstats.cyclestatsacc.sweeptime += g->gcstats.lastcycle.sweeptime; + g->gcmetrics.currcycle.starttotalsizebytes = g->totalbytes; + g->gcmetrics.currcycle.heaptriggersizebytes = g->GCthreshold; } +#endif static void removeentry(LuaNode* n) { @@ -598,20 +615,19 @@ static size_t atomic(lua_State* L) LUAU_ASSERT(g->gcstate == GCSatomic); size_t work = 0; + +#ifdef LUAI_GCMETRICS double currts = lua_clock(); - double prevts = currts; +#endif /* remark occasional upvalues of (maybe) dead threads */ work += remarkupvals(g); /* traverse objects caught by write barrier and by 'remarkupvals' */ work += propagateall(g); - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimeupval += currts - prevts; - prevts = currts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimeupval += recordGcDeltaTime(currts); +#endif /* remark weak tables */ g->gray = g->weak; @@ -621,34 +637,26 @@ static size_t atomic(lua_State* L) markmt(g); /* mark basic metatables (again) */ work += propagateall(g); - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimeweak += currts - prevts; - prevts = currts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimeweak += recordGcDeltaTime(currts); +#endif /* remark gray again */ g->gray = g->grayagain; g->grayagain = NULL; work += propagateall(g); - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimegray += currts - prevts; - prevts = currts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimegray += recordGcDeltaTime(currts); +#endif /* remove collected objects from weak tables */ work += cleartable(L, g->weak); g->weak = NULL; - if (FFlag::LuauGcAdditionalStats) - { - currts = lua_clock(); - g->gcstats.currcycle.atomictimeclear += currts - prevts; - } +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomictimeclear += recordGcDeltaTime(currts); +#endif /* flip current white */ g->currentwhite = cast_byte(otherwhite(g)); @@ -742,8 +750,9 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) { - if (FFlag::LuauGcAdditionalStats) - g->gcstats.currcycle.propagatework = g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork; +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.propagatework = g->gcmetrics.currcycle.explicitwork + g->gcmetrics.currcycle.assistwork; +#endif // perform one iteration over 'gray again' list g->gray = g->grayagain; @@ -762,9 +771,10 @@ static size_t gcstep(lua_State* L, size_t limit) if (!g->gray) /* no more `gray' objects */ { - if (FFlag::LuauGcAdditionalStats) - g->gcstats.currcycle.propagateagainwork = - g->gcstats.currcycle.explicitwork + g->gcstats.currcycle.assistwork - g->gcstats.currcycle.propagatework; +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.propagateagainwork = + g->gcmetrics.currcycle.explicitwork + g->gcmetrics.currcycle.assistwork - g->gcmetrics.currcycle.propagatework; +#endif g->gcstate = GCSatomic; } @@ -772,8 +782,13 @@ static size_t gcstep(lua_State* L, size_t limit) } case GCSatomic: { - g->gcstats.currcycle.atomicstarttimestamp = lua_clock(); - g->gcstats.currcycle.atomicstarttotalsizebytes = g->totalbytes; +#ifdef LUAI_GCMETRICS + g->gcmetrics.currcycle.atomicstarttimestamp = lua_clock(); + g->gcmetrics.currcycle.atomicstarttotalsizebytes = g->totalbytes; +#endif + + g->gcstats.atomicstarttimestamp = lua_clock(); + g->gcstats.atomicstarttotalsizebytes = g->totalbytes; cost = atomic(L); /* finish mark phase */ @@ -809,18 +824,20 @@ static size_t gcstep(lua_State* L, size_t limit) return cost; } -static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCycleStats* cyclestats) +static int64_t getheaptriggererroroffset(global_State* g) { // adjust for error using Proportional-Integral controller // https://en.wikipedia.org/wiki/PID_controller - int32_t errorKb = int32_t((cyclestats->atomicstarttotalsizebytes - cyclestats->heapgoalsizebytes) / 1024); + int32_t errorKb = int32_t((g->gcstats.atomicstarttotalsizebytes - g->gcstats.heapgoalsizebytes) / 1024); // we use sliding window for the error integral to avoid error sum 'windup' when the desired target cannot be reached - int32_t* slot = &triggerstats->terms[triggerstats->termpos % triggerstats->termcount]; + const size_t triggertermcount = sizeof(g->gcstats.triggerterms) / sizeof(g->gcstats.triggerterms[0]); + + int32_t* slot = &g->gcstats.triggerterms[g->gcstats.triggertermpos % triggertermcount]; int32_t prev = *slot; *slot = errorKb; - triggerstats->integral += errorKb - prev; - triggerstats->termpos++; + g->gcstats.triggerintegral += errorKb - prev; + g->gcstats.triggertermpos++; // controller tuning // https://en.wikipedia.org/wiki/Ziegler%E2%80%93Nichols_method @@ -832,7 +849,7 @@ static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCyc const double Ki = 0.54 * Ku / Ti; // integral gain double proportionalTerm = Kp * errorKb; - double integralTerm = Ki * triggerstats->integral; + double integralTerm = Ki * g->gcstats.triggerintegral; double totalTerm = proportionalTerm + integralTerm; @@ -841,23 +858,20 @@ static int64_t getheaptriggererroroffset(GCHeapTriggerStats* triggerstats, GCCyc static size_t getheaptrigger(global_State* g, size_t heapgoal) { - GCCycleStats* lastcycle = &g->gcstats.lastcycle; - GCCycleStats* currcycle = &g->gcstats.currcycle; - // adjust threshold based on a guess of how many bytes will be allocated between the cycle start and sweep phase // our goal is to begin the sweep when used memory has reached the heap goal const double durationthreshold = 1e-3; - double allocationduration = currcycle->atomicstarttimestamp - lastcycle->endtimestamp; + double allocationduration = g->gcstats.atomicstarttimestamp - g->gcstats.endtimestamp; // avoid measuring intervals smaller than 1ms if (allocationduration < durationthreshold) return heapgoal; - double allocationrate = (currcycle->atomicstarttotalsizebytes - lastcycle->endtotalsizebytes) / allocationduration; - double markduration = currcycle->atomicstarttimestamp - currcycle->starttimestamp; + double allocationrate = (g->gcstats.atomicstarttotalsizebytes - g->gcstats.endtotalsizebytes) / allocationduration; + double markduration = g->gcstats.atomicstarttimestamp - g->gcstats.starttimestamp; int64_t expectedgrowth = int64_t(markduration * allocationrate); - int64_t offset = getheaptriggererroroffset(&g->gcstats.triggerstats, currcycle); + int64_t offset = getheaptriggererroroffset(g); int64_t heaptrigger = heapgoal - (expectedgrowth + offset); // clamp the trigger between memory use at the end of the cycle and the heap goal @@ -868,11 +882,6 @@ void luaC_step(lua_State* L, bool assist) { global_State* g = L->global; - if (assist) - g->gcstats.currcycle.assistrequests += g->gcstepsize; - else - g->gcstats.currcycle.explicitrequests += g->gcstepsize; - int lim = (g->gcstepsize / 100) * g->gcstepmul; /* how much to work */ LUAU_ASSERT(g->totalbytes >= g->GCthreshold); size_t debt = g->totalbytes - g->GCthreshold; @@ -881,24 +890,23 @@ void luaC_step(lua_State* L, bool assist) // at the start of the new cycle if (g->gcstate == GCSpause) - startGcCycleStats(g); + g->gcstats.starttimestamp = lua_clock(); - int lastgcstate = g->gcstate; - double lasttimestamp = lua_clock(); +#ifdef LUAI_GCMETRICS + if (g->gcstate == GCSpause) + startGcCycleMetrics(g); - size_t work = gcstep(L, lim); + double lasttimestamp = lua_clock(); +#endif - if (assist) - g->gcstats.currcycle.assistwork += work; - else - g->gcstats.currcycle.explicitwork += work; + int lastgcstate = g->gcstate; - recordGcStateTime(g, lastgcstate, lua_clock() - lasttimestamp, assist); + size_t work = gcstep(L, lim); + (void)work; - if (lastgcstate == GCSpropagate) - g->gcstats.currcycle.markrequests += g->gcstepsize; - else if (lastgcstate == GCSsweep) - g->gcstats.currcycle.sweeprequests += g->gcstepsize; +#ifdef LUAI_GCMETRICS + recordGcStateStep(g, lastgcstate, lua_clock() - lasttimestamp, assist, work); +#endif // at the end of the last cycle if (g->gcstate == GCSpause) @@ -909,13 +917,13 @@ void luaC_step(lua_State* L, bool assist) g->GCthreshold = heaptrigger; - finishGcCycleStats(g); + g->gcstats.heapgoalsizebytes = heapgoal; + g->gcstats.endtimestamp = lua_clock(); + g->gcstats.endtotalsizebytes = g->totalbytes; - if (FFlag::LuauGcAdditionalStats) - g->gcstats.currcycle.starttotalsizebytes = g->totalbytes; - - g->gcstats.currcycle.heapgoalsizebytes = heapgoal; - g->gcstats.currcycle.heaptriggersizebytes = heaptrigger; +#ifdef LUAI_GCMETRICS + finishGcCycleMetrics(g); +#endif } else { @@ -933,8 +941,10 @@ void luaC_fullgc(lua_State* L) { global_State* g = L->global; +#ifdef LUAI_GCMETRICS if (g->gcstate == GCSpause) - startGcCycleStats(g); + startGcCycleMetrics(g); +#endif if (g->gcstate <= GCSatomic) { @@ -954,11 +964,12 @@ void luaC_fullgc(lua_State* L) gcstep(L, SIZE_MAX); } - finishGcCycleStats(g); +#ifdef LUAI_GCMETRICS + finishGcCycleMetrics(g); + startGcCycleMetrics(g); +#endif /* run a full collection cycle */ - startGcCycleStats(g); - markroot(L); while (g->gcstate != GCSpause) { @@ -980,10 +991,11 @@ void luaC_fullgc(lua_State* L) if (g->GCthreshold < g->totalbytes) g->GCthreshold = g->totalbytes; - finishGcCycleStats(g); + g->gcstats.heapgoalsizebytes = heapgoalsizebytes; - g->gcstats.currcycle.heapgoalsizebytes = heapgoalsizebytes; - g->gcstats.currcycle.heaptriggersizebytes = g->GCthreshold; +#ifdef LUAI_GCMETRICS + finishGcCycleMetrics(g); +#endif } void luaC_barrierupval(lua_State* L, GCObject* v) @@ -1075,21 +1087,21 @@ int64_t luaC_allocationrate(lua_State* L) if (g->gcstate <= GCSatomic) { - double duration = lua_clock() - g->gcstats.lastcycle.endtimestamp; + double duration = lua_clock() - g->gcstats.endtimestamp; if (duration < durationthreshold) return -1; - return int64_t((g->totalbytes - g->gcstats.lastcycle.endtotalsizebytes) / duration); + return int64_t((g->totalbytes - g->gcstats.endtotalsizebytes) / duration); } // totalbytes is unstable during the sweep, use the rate measured at the end of mark phase - double duration = g->gcstats.currcycle.atomicstarttimestamp - g->gcstats.lastcycle.endtimestamp; + double duration = g->gcstats.atomicstarttimestamp - g->gcstats.endtimestamp; if (duration < durationthreshold) return -1; - return int64_t((g->gcstats.currcycle.atomicstarttotalsizebytes - g->gcstats.lastcycle.endtotalsizebytes) / duration); + return int64_t((g->gcstats.atomicstarttotalsizebytes - g->gcstats.endtotalsizebytes) / duration); } void luaC_wakethread(lua_State* L) diff --git a/VM/src/lgc.h b/VM/src/lgc.h index ebf999b53..dcd070b70 100644 --- a/VM/src/lgc.h +++ b/VM/src/lgc.h @@ -82,7 +82,7 @@ #define luaC_checkGC(L) \ { \ - condhardstacktests(luaD_reallocstack(L, L->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK))); \ + condhardstacktests(luaD_reallocstack(L, L->stacksize - EXTRA_STACK)); \ if (L->global->totalbytes >= L->global->GCthreshold) \ { \ condhardmemtests(luaC_validate(L), 1); \ diff --git a/VM/src/lstate.cpp b/VM/src/lstate.cpp index d4f3f0a19..fbc6fb1e4 100644 --- a/VM/src/lstate.cpp +++ b/VM/src/lstate.cpp @@ -10,8 +10,6 @@ #include "ldo.h" #include "ldebug.h" -LUAU_FASTFLAGVARIABLE(LuauReduceStackReallocs, false) - /* ** Main thread combines a thread state and the global state */ @@ -35,7 +33,7 @@ static void stack_init(lua_State* L1, lua_State* L) for (int i = 0; i < BASIC_STACK_SIZE + EXTRA_STACK; i++) setnilvalue(stack + i); /* erase new stack */ L1->top = stack; - L1->stack_last = stack + (L1->stacksize - (FFlag::LuauReduceStackReallocs ? EXTRA_STACK : 1 + EXTRA_STACK)); + L1->stack_last = stack + (L1->stacksize - EXTRA_STACK); /* initialize first ci */ L1->ci->func = L1->top; setnilvalue(L1->top++); /* `function' entry for this `ci' */ @@ -141,30 +139,16 @@ void lua_resetthread(lua_State* L) ci->top = ci->base + LUA_MINSTACK; setnilvalue(ci->func); L->ci = ci; - if (FFlag::LuauReduceStackReallocs) - { - if (L->size_ci != BASIC_CI_SIZE) - luaD_reallocCI(L, BASIC_CI_SIZE); - } - else - { + if (L->size_ci != BASIC_CI_SIZE) luaD_reallocCI(L, BASIC_CI_SIZE); - } /* clear thread state */ L->status = LUA_OK; L->base = L->ci->base; L->top = L->ci->base; L->nCcalls = L->baseCcalls = 0; /* clear thread stack */ - if (FFlag::LuauReduceStackReallocs) - { - if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) - luaD_reallocstack(L, BASIC_STACK_SIZE); - } - else - { + if (L->stacksize != BASIC_STACK_SIZE + EXTRA_STACK) luaD_reallocstack(L, BASIC_STACK_SIZE); - } for (int i = 0; i < L->stacksize; i++) setnilvalue(L->stack + i); } @@ -234,6 +218,10 @@ lua_State* lua_newstate(lua_Alloc f, void* ud) g->cb = lua_Callbacks(); g->gcstats = GCStats(); +#ifdef LUAI_GCMETRICS + g->gcmetrics = GCMetrics(); +#endif + if (luaD_rawrunprotected(L, f_luaopen, NULL) != 0) { /* memory allocation error: free partial state */ diff --git a/VM/src/lstate.h b/VM/src/lstate.h index b2bedb486..e7c373736 100644 --- a/VM/src/lstate.h +++ b/VM/src/lstate.h @@ -75,10 +75,26 @@ typedef struct CallInfo #define f_isLua(ci) (!ci_func(ci)->isC) #define isLua(ci) (ttisfunction((ci)->func) && f_isLua(ci)) -struct GCCycleStats +struct GCStats { - size_t starttotalsizebytes = 0; + // data for proportional-integral controller of heap trigger value + int32_t triggerterms[32] = {0}; + uint32_t triggertermpos = 0; + int32_t triggerintegral = 0; + + size_t atomicstarttotalsizebytes = 0; + size_t endtotalsizebytes = 0; size_t heapgoalsizebytes = 0; + + double starttimestamp = 0; + double atomicstarttimestamp = 0; + double endtimestamp = 0; +}; + +#ifdef LUAI_GCMETRICS +struct GCCycleMetrics +{ + size_t starttotalsizebytes = 0; size_t heaptriggersizebytes = 0; double pausetime = 0.0; // time from end of the last cycle to the start of a new one @@ -120,16 +136,7 @@ struct GCCycleStats size_t endtotalsizebytes = 0; }; -// data for proportional-integral controller of heap trigger value -struct GCHeapTriggerStats -{ - static const unsigned termcount = 32; - int32_t terms[termcount] = {0}; - uint32_t termpos = 0; - int32_t integral = 0; -}; - -struct GCStats +struct GCMetrics { double stepexplicittimeacc = 0.0; double stepassisttimeacc = 0.0; @@ -137,14 +144,10 @@ struct GCStats // when cycle is completed, last cycle values are updated uint64_t completedcycles = 0; - GCCycleStats lastcycle; - GCCycleStats currcycle; - - // only step count and their time is accumulated - GCCycleStats cyclestatsacc; - - GCHeapTriggerStats triggerstats; + GCCycleMetrics lastcycle; + GCCycleMetrics currcycle; }; +#endif /* ** `global state', shared by all threads of this state @@ -206,6 +209,9 @@ typedef struct global_State GCStats gcstats; +#ifdef LUAI_GCMETRICS + GCMetrics gcmetrics; +#endif } global_State; // clang-format on diff --git a/VM/src/ltable.cpp b/VM/src/ltable.cpp index 2deec2b9a..431501f3b 100644 --- a/VM/src/ltable.cpp +++ b/VM/src/ltable.cpp @@ -526,8 +526,8 @@ static TValue* newkey(lua_State* L, Table* t, const TValue* key) LuaNode* othern; LuaNode* n = getfreepos(t); /* get a free place */ if (n == NULL) - { /* cannot find a free place? */ - rehash(L, t, key); /* grow table */ + { /* cannot find a free place? */ + rehash(L, t, key); /* grow table */ if (!FFlag::LuauTableRehashRework) { diff --git a/tests/Autocomplete.test.cpp b/tests/Autocomplete.test.cpp index 17fd6b133..1db782cce 100644 --- a/tests/Autocomplete.test.cpp +++ b/tests/Autocomplete.test.cpp @@ -2726,11 +2726,6 @@ end TEST_CASE_FIXTURE(ACFixture, "autocomplete_on_string_singletons") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, - }; - check(R"( --!strict local foo: "hello" | "bye" = "hello" diff --git a/tests/Conformance.test.cpp b/tests/Conformance.test.cpp index 9e4cb4a59..83d4518de 100644 --- a/tests/Conformance.test.cpp +++ b/tests/Conformance.test.cpp @@ -496,8 +496,6 @@ TEST_CASE("DateTime") TEST_CASE("Debug") { - ScopedFastFlag luauTableFieldFunctionDebugname{"LuauTableFieldFunctionDebugname", true}; - runConformance("debug.lua"); } diff --git a/tests/Fixture.cpp b/tests/Fixture.cpp index dbdd06a44..a7e7ea39d 100644 --- a/tests/Fixture.cpp +++ b/tests/Fixture.cpp @@ -133,26 +133,19 @@ AstStatBlock* Fixture::parse(const std::string& source, const ParseOptions& pars CheckResult Fixture::check(Mode mode, std::string source) { + ModuleName mm = fromString(mainModuleName); configResolver.defaultConfig.mode = mode; - fileResolver.source[mainModuleName] = std::move(source); - - CheckResult result = frontend.check(fromString(mainModuleName)); + fileResolver.source[mm] = std::move(source); + frontend.markDirty(mm); - configResolver.defaultConfig.mode = Mode::Strict; + CheckResult result = frontend.check(mm); return result; } CheckResult Fixture::check(const std::string& source) { - ModuleName mm = fromString(mainModuleName); - configResolver.defaultConfig.mode = Mode::Strict; - fileResolver.source[mm] = std::move(source); - frontend.markDirty(mm); - - CheckResult result = frontend.check(mm); - - return result; + return check(Mode::Strict, source); } LintResult Fixture::lint(const std::string& source, const std::optional& lintOptions) diff --git a/tests/Linter.test.cpp b/tests/Linter.test.cpp index 91b23197c..9ce9a4c25 100644 --- a/tests/Linter.test.cpp +++ b/tests/Linter.test.cpp @@ -597,6 +597,8 @@ return foo1 TEST_CASE_FIXTURE(Fixture, "UnknownType") { + ScopedFastFlag sff("LuauLintNoRobloxBits", true); + unfreeze(typeChecker.globalTypes); TableTypeVar::Props instanceProps{ {"ClassName", {typeChecker.anyType}}, @@ -606,81 +608,26 @@ TEST_CASE_FIXTURE(Fixture, "UnknownType") TypeId instanceType = typeChecker.globalTypes.addType(instanceTable); TypeFun instanceTypeFun{{}, instanceType}; - ClassTypeVar::Props enumItemProps{ - {"EnumType", {typeChecker.anyType}}, - }; - - ClassTypeVar enumItemClass{"EnumItem", enumItemProps, std::nullopt, std::nullopt, {}, {}}; - TypeId enumItemType = typeChecker.globalTypes.addType(enumItemClass); - TypeFun enumItemTypeFun{{}, enumItemType}; - - ClassTypeVar normalIdClass{"NormalId", {}, enumItemType, std::nullopt, {}, {}}; - TypeId normalIdType = typeChecker.globalTypes.addType(normalIdClass); - TypeFun normalIdTypeFun{{}, normalIdType}; - - // Normally this would be defined externally, so hack it in for testing - addGlobalBinding(typeChecker, "game", typeChecker.anyType, "@test"); - addGlobalBinding(typeChecker, "typeof", typeChecker.anyType, "@test"); typeChecker.globalScope->exportedTypeBindings["Part"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["Workspace"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["RunService"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["Instance"] = instanceTypeFun; - typeChecker.globalScope->exportedTypeBindings["ColorSequence"] = TypeFun{{}, typeChecker.anyType}; - typeChecker.globalScope->exportedTypeBindings["EnumItem"] = enumItemTypeFun; - typeChecker.globalScope->importedTypeBindings["Enum"] = {{"NormalId", normalIdTypeFun}}; - freeze(typeChecker.globalTypes); LintResult result = lint(R"( -local _e01 = game:GetService("Foo") -local _e02 = game:GetService("NormalId") -local _e03 = game:FindService("table") -local _e04 = type(game) == "Part" -local _e05 = type(game) == "NormalId" -local _e06 = typeof(game) == "Bar" -local _e07 = typeof(game) == "Part" -local _e08 = typeof(game) == "vector" -local _e09 = typeof(game) == "NormalId" -local _e10 = game:IsA("ColorSequence") -local _e11 = game:IsA("Enum.NormalId") -local _e12 = game:FindFirstChildWhichIsA("function") - -local _o01 = game:GetService("Workspace") -local _o02 = game:FindService("RunService") -local _o03 = type(game) == "number" -local _o04 = type(game) == "vector" -local _o05 = typeof(game) == "string" -local _o06 = typeof(game) == "Instance" -local _o07 = typeof(game) == "EnumItem" -local _o08 = game:IsA("Part") -local _o09 = game:IsA("NormalId") -local _o10 = game:FindFirstChildWhichIsA("Part") +local game = ... +local _e01 = type(game) == "Part" +local _e02 = typeof(game) == "Bar" +local _e03 = typeof(game) == "vector" + +local _o01 = type(game) == "number" +local _o02 = type(game) == "vector" +local _o03 = typeof(game) == "Part" )"); - REQUIRE_EQ(result.warnings.size(), 12); - CHECK_EQ(result.warnings[0].location.begin.line, 1); - CHECK_EQ(result.warnings[0].text, "Unknown type 'Foo'"); - CHECK_EQ(result.warnings[1].location.begin.line, 2); - CHECK_EQ(result.warnings[1].text, "Unknown type 'NormalId' (expected class type)"); - CHECK_EQ(result.warnings[2].location.begin.line, 3); - CHECK_EQ(result.warnings[2].text, "Unknown type 'table' (expected class type)"); - CHECK_EQ(result.warnings[3].location.begin.line, 4); - CHECK_EQ(result.warnings[3].text, "Unknown type 'Part' (expected primitive type)"); - CHECK_EQ(result.warnings[4].location.begin.line, 5); - CHECK_EQ(result.warnings[4].text, "Unknown type 'NormalId' (expected primitive type)"); - CHECK_EQ(result.warnings[5].location.begin.line, 6); - CHECK_EQ(result.warnings[5].text, "Unknown type 'Bar'"); - CHECK_EQ(result.warnings[6].location.begin.line, 7); - CHECK_EQ(result.warnings[6].text, "Unknown type 'Part' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[7].location.begin.line, 8); - CHECK_EQ(result.warnings[7].text, "Unknown type 'vector' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[8].location.begin.line, 9); - CHECK_EQ(result.warnings[8].text, "Unknown type 'NormalId' (expected primitive or userdata type)"); - CHECK_EQ(result.warnings[9].location.begin.line, 10); - CHECK_EQ(result.warnings[9].text, "Unknown type 'ColorSequence' (expected class or enum type)"); - CHECK_EQ(result.warnings[10].location.begin.line, 11); - CHECK_EQ(result.warnings[10].text, "Unknown type 'Enum.NormalId'"); - CHECK_EQ(result.warnings[11].location.begin.line, 12); - CHECK_EQ(result.warnings[11].text, "Unknown type 'function' (expected class type)"); + REQUIRE_EQ(result.warnings.size(), 3); + CHECK_EQ(result.warnings[0].location.begin.line, 2); + CHECK_EQ(result.warnings[0].text, "Unknown type 'Part' (expected primitive type)"); + CHECK_EQ(result.warnings[1].location.begin.line, 3); + CHECK_EQ(result.warnings[1].text, "Unknown type 'Bar'"); + CHECK_EQ(result.warnings[2].location.begin.line, 4); + CHECK_EQ(result.warnings[2].text, "Unknown type 'vector' (expected primitive or userdata type)"); } TEST_CASE_FIXTURE(Fixture, "ForRangeTable") diff --git a/tests/ToString.test.cpp b/tests/ToString.test.cpp index 6713a589d..3051e2090 100644 --- a/tests/ToString.test.cpp +++ b/tests/ToString.test.cpp @@ -470,6 +470,7 @@ TEST_CASE_FIXTURE(Fixture, "self_recursive_instantiated_param") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function id(x) return x end )"); @@ -482,6 +483,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_id") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function map(arr, fn) local t = {} @@ -500,6 +502,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_map") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(a: number, b: string) end local function test(...: T...): U... @@ -516,6 +519,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_generic_pack") TEST_CASE("toStringNamedFunction_unit_f") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; TypePackVar empty{TypePack{}}; FunctionTypeVar ftv{&empty, &empty, {}, false}; CHECK_EQ("f(): ()", toStringNamedFunction("f", ftv)); @@ -523,6 +527,7 @@ TEST_CASE("toStringNamedFunction_unit_f") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: a, ...): (a, a, b...) return x, x, ... @@ -537,6 +542,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): ...number return 1, 2, 3 @@ -551,6 +557,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics2") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(): (string, ...number) return 'a', 1, 2, 3 @@ -565,6 +572,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_variadics3") TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_argnames") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local f: (number, y: number) -> number )"); @@ -577,6 +585,7 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_type_annotation_has_partial_ar TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") { + ScopedFastFlag flag{"LuauDocFuncParameters", true}; CheckResult result = check(R"( local function f(x: T, g: (T) -> U)): () end @@ -590,4 +599,20 @@ TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_hide_type_params") CHECK_EQ("f(x: T, g: (T) -> U): ()", toStringNamedFunction("f", *ftv, opts)); } +TEST_CASE_FIXTURE(Fixture, "toStringNamedFunction_overrides_param_names") +{ + ScopedFastFlag flag{"LuauDocFuncParameters", true}; + + CheckResult result = check(R"( + local function test(a, b : string, ... : number) return a end + )"); + + TypeId ty = requireType("test"); + const FunctionTypeVar* ftv = get(follow(ty)); + + ToStringOptions opts; + opts.namedFunctionOverrideArgNames = {"first", "second", "third"}; + CHECK_EQ("test(first: a, second: string, ...: number): a", toStringNamedFunction("test", *ftv, opts)); +} + TEST_SUITE_END(); diff --git a/tests/Transpiler.test.cpp b/tests/Transpiler.test.cpp index 5f0295b0e..5ac45ff21 100644 --- a/tests/Transpiler.test.cpp +++ b/tests/Transpiler.test.cpp @@ -651,8 +651,6 @@ local a: Packed TEST_CASE_FIXTURE(Fixture, "transpile_singleton_types") { - ScopedFastFlag luauParseSingletonTypes{"LuauParseSingletonTypes", true}; - std::string code = R"( type t1 = 'hello' type t2 = true diff --git a/tests/TypeInfer.builtins.test.cpp b/tests/TypeInfer.builtins.test.cpp index ec20a2c7f..c6fbebedb 100644 --- a/tests/TypeInfer.builtins.test.cpp +++ b/tests/TypeInfer.builtins.test.cpp @@ -887,8 +887,6 @@ TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types") TEST_CASE_FIXTURE(Fixture, "assert_removes_falsy_types2") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauAssertStripsFalsyTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, diff --git a/tests/TypeInfer.functions.test.cpp b/tests/TypeInfer.functions.test.cpp index 4288098a3..da4ea074a 100644 --- a/tests/TypeInfer.functions.test.cpp +++ b/tests/TypeInfer.functions.test.cpp @@ -1335,4 +1335,80 @@ caused by: toString(result.errors[0])); } +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic") +{ + ScopedFastFlag sff{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; + CheckResult result = check(R"( + function test(a: number, b: string, ...) + end + + test(1) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(2, acm->expected); + CHECK_EQ(1, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic") +{ + ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; + ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; + CheckResult result = check(R"( +function test(a: number, b: string, ...) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +wrapper(test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(3, acm->expected); + CHECK_EQ(1, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + +TEST_CASE_FIXTURE(Fixture, "too_few_arguments_variadic_generic2") +{ + ScopedFastFlag sff1{"LuauArgCountMismatchSaysAtLeastWhenVariadic", true}; + ScopedFastFlag sff2{"LuauFixArgumentCountMismatchAmountWithGenericTypes", true}; + CheckResult result = check(R"( +function test(a: number, b: string, ...) + return 1 +end + +function wrapper(f: (A...) -> number, ...: A...) +end + +pcall(wrapper, test) + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + + auto err = result.errors[0]; + auto acm = get(err); + REQUIRE(acm); + + CHECK_EQ(4, acm->expected); + CHECK_EQ(2, acm->actual); + CHECK_EQ(CountMismatch::Context::Arg, acm->context); + CHECK(acm->isVariadic); +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.modules.test.cpp b/tests/TypeInfer.modules.test.cpp index 636436103..e5eeae317 100644 --- a/tests/TypeInfer.modules.test.cpp +++ b/tests/TypeInfer.modules.test.cpp @@ -13,6 +13,8 @@ using namespace Luau; +LUAU_FASTFLAG(LuauTableSubtypingVariance2) + TEST_SUITE_BEGIN("TypeInferModules"); TEST_CASE_FIXTURE(Fixture, "require") @@ -268,8 +270,6 @@ function x:Destroy(): () end TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_2") { - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - fileResolver.source["game/A"] = R"( export type Type = { x: { a: number } } return {} @@ -288,8 +288,6 @@ type Rename = typeof(x.x) TEST_CASE_FIXTURE(Fixture, "do_not_modify_imported_types_3") { - ScopedFastFlag immutableTypes{"LuauImmutableTypes", true}; - fileResolver.source["game/A"] = R"( local y = setmetatable({}, {}) export type Type = { x: typeof(y) } @@ -307,4 +305,83 @@ type Rename = typeof(x.x) LUAU_REQUIRE_NO_ERRORS(result); } +TEST_CASE_FIXTURE(Fixture, "module_type_conflict") +{ + ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; + + fileResolver.source["game/A"] = R"( +export type T = { x: number } +return {} + )"; + + fileResolver.source["game/B"] = R"( +export type T = { x: string } +return {} + )"; + + fileResolver.source["game/C"] = R"( +local A = require(game.A) +local B = require(game.B) +local a: A.T = { x = 2 } +local b: B.T = a + )"; + + CheckResult result = frontend.check("game/C"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTableSubtypingVariance2) + { + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/A' could not be converted into 'T' from 'game/B' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); + } + else + { + CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/A' could not be converted into 'T' from 'game/B'"); + } +} + +TEST_CASE_FIXTURE(Fixture, "module_type_conflict_instantiated") +{ + ScopedFastFlag luauTypeMismatchModuleName{"LuauTypeMismatchModuleName", true}; + + fileResolver.source["game/A"] = R"( +export type Wrap = { x: T } +return {} + )"; + + fileResolver.source["game/B"] = R"( +local A = require(game.A) +export type T = A.Wrap +return {} + )"; + + fileResolver.source["game/C"] = R"( +local A = require(game.A) +export type T = A.Wrap +return {} + )"; + + fileResolver.source["game/D"] = R"( +local A = require(game.B) +local B = require(game.C) +local a: A.T = { x = 2 } +local b: B.T = a + )"; + + CheckResult result = frontend.check("game/D"); + LUAU_REQUIRE_ERROR_COUNT(1, result); + + if (FFlag::LuauTableSubtypingVariance2) + { + CHECK_EQ(toString(result.errors[0]), R"(Type 'T' from 'game/B' could not be converted into 'T' from 'game/C' +caused by: + Property 'x' is not compatible. Type 'number' could not be converted into 'string')"); + } + else + { + CHECK_EQ(toString(result.errors[0]), "Type 'T' from 'game/B' could not be converted into 'T' from 'game/C'"); + } +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.operators.test.cpp b/tests/TypeInfer.operators.test.cpp index baa259783..6a8a9d93f 100644 --- a/tests/TypeInfer.operators.test.cpp +++ b/tests/TypeInfer.operators.test.cpp @@ -756,4 +756,30 @@ TEST_CASE_FIXTURE(Fixture, "refine_and_or") CHECK_EQ("number", toString(requireType("u"))); } +TEST_CASE_FIXTURE(Fixture, "infer_any_in_all_modes_when_lhs_is_unknown") +{ + ScopedFastFlag sff{"LuauDecoupleOperatorInferenceFromUnifiedTypeInference", true}; + + CheckResult result = check(Mode::Strict, R"( + local function f(x, y) + return x + y + end + )"); + + LUAU_REQUIRE_ERROR_COUNT(1, result); + CHECK_EQ(toString(result.errors[0]), "Unknown type used in + operation; consider adding a type annotation to 'x'"); + + result = check(Mode::Nonstrict, R"( + local function f(x, y) + return x + y + end + )"); + + LUAU_REQUIRE_NO_ERRORS(result); + + // When type inference is unified, we could add an assertion that + // the strict and nonstrict types are equivalent. This isn't actually + // the case right now, though. +} + TEST_SUITE_END(); diff --git a/tests/TypeInfer.refinements.test.cpp b/tests/TypeInfer.refinements.test.cpp index 9b347921f..cddeab6ec 100644 --- a/tests/TypeInfer.refinements.test.cpp +++ b/tests/TypeInfer.refinements.test.cpp @@ -435,7 +435,6 @@ TEST_CASE_FIXTURE(Fixture, "term_is_equal_to_an_lvalue") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -1002,8 +1001,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_from_truthiness_of_x") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -1028,8 +1025,6 @@ TEST_CASE_FIXTURE(Fixture, "discriminate_tag") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -1066,8 +1061,6 @@ TEST_CASE_FIXTURE(Fixture, "and_or_peephole_refinement") TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1091,8 +1084,6 @@ TEST_CASE_FIXTURE(Fixture, "narrow_boolean_to_true_or_false") TEST_CASE_FIXTURE(Fixture, "discriminate_on_properties_of_disjoint_tables_where_that_property_is_true_or_false") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauAssertStripsFalsyTypes", true}, }; @@ -1134,8 +1125,6 @@ TEST_CASE_FIXTURE(RefinementClassFixture, "discriminate_from_isa_of_x") { ScopedFastFlag sff[] = { {"LuauDiscriminableUnions2", true}, - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.singletons.test.cpp b/tests/TypeInfer.singletons.test.cpp index 7f8d8fec2..d39341eac 100644 --- a/tests/TypeInfer.singletons.test.cpp +++ b/tests/TypeInfer.singletons.test.cpp @@ -13,11 +13,6 @@ TEST_SUITE_BEGIN("TypeSingletons"); TEST_CASE_FIXTURE(Fixture, "bool_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: true = true local b: false = false @@ -28,11 +23,6 @@ TEST_CASE_FIXTURE(Fixture, "bool_singletons") TEST_CASE_FIXTURE(Fixture, "string_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "foo" = "foo" local b: "bar" = "bar" @@ -43,11 +33,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons") TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: true = false )"); @@ -58,11 +43,6 @@ TEST_CASE_FIXTURE(Fixture, "bool_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "foo" = "bar" )"); @@ -73,11 +53,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "\n" = "\000\r" )"); @@ -88,11 +63,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singletons_escape_chars") TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: true = true local b: boolean = a @@ -103,11 +73,6 @@ TEST_CASE_FIXTURE(Fixture, "bool_singleton_subtype") TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( local a: "foo" = "foo" local b: string = a @@ -118,11 +83,6 @@ TEST_CASE_FIXTURE(Fixture, "string_singleton_subtype") TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a: true, b: "foo") end f(true, "foo") @@ -133,11 +93,6 @@ TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons") TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a: true, b: "foo") end f(true, "bar") @@ -149,11 +104,6 @@ TEST_CASE_FIXTURE(Fixture, "function_call_with_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a, b) end local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) @@ -166,11 +116,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons") TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( function f(a, b) end local g : ((true, string) -> ()) & ((false, number) -> ()) = (f::any) @@ -184,11 +129,6 @@ TEST_CASE_FIXTURE(Fixture, "overloaded_function_call_with_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "foo" @@ -201,11 +141,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type MyEnum = "foo" | "bar" | "baz" local a : MyEnum = "bang" @@ -218,11 +153,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type MyEnum1 = "foo" | "bar" type MyEnum2 = MyEnum1 | "baz" @@ -237,8 +167,6 @@ TEST_CASE_FIXTURE(Fixture, "enums_using_singletons_subtyping") TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -257,11 +185,6 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons") TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } @@ -274,11 +197,6 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_using_singletons_mismatch") TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") { - ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( type Dog = { tag: "Dog", howls: boolean } type Cat = { tag: "Cat", meows: boolean } @@ -292,10 +210,6 @@ TEST_CASE_FIXTURE(Fixture, "tagged_unions_immutable_tag") TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( --!strict type T = { @@ -320,10 +234,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings") } TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( --!strict type T = { @@ -341,10 +251,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_singleton_strings_mismatch") TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, - }; - CheckResult result = check(R"( --!strict type S = "bar" @@ -367,8 +273,7 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_alias_or_parens_is_indexer") TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") { - ScopedFastFlag sffs[] = { - {"LuauParseSingletonTypes", true}, + ScopedFastFlag sffs[]{ {"LuauUnsealedTableLiteral", true}, }; @@ -386,8 +291,6 @@ TEST_CASE_FIXTURE(Fixture, "table_properties_type_error_escapes") TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_string") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -409,8 +312,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "error_detailed_tagged_union_mismatch_bool") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -432,8 +333,6 @@ caused by: TEST_CASE_FIXTURE(Fixture, "if_then_else_expression_singleton_options") { ScopedFastFlag sffs[] = { - {"LuauSingletonTypes", true}, - {"LuauParseSingletonTypes", true}, {"LuauExpectedTypesOfProperties", true}, }; @@ -451,7 +350,6 @@ local a: Animal = if true then { tag = 'cat', catfood = 'something' } else { tag TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_singleton") { ScopedFastFlag sff[]{ - {"LuauSingletonTypes", true}, {"LuauEqConstraint", true}, {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, @@ -477,8 +375,6 @@ TEST_CASE_FIXTURE(Fixture, "widen_the_supertype_if_it_is_free_and_subtype_has_si TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauEqConstraint", true}, {"LuauWidenIfSupertypeIsFree2", true}, @@ -504,8 +400,6 @@ TEST_CASE_FIXTURE(Fixture, "return_type_of_f_is_not_widened") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -521,8 +415,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere") TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauDiscriminableUnions2", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -551,8 +443,6 @@ TEST_CASE_FIXTURE(Fixture, "widening_happens_almost_everywhere_except_for_tables TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -577,8 +467,6 @@ TEST_CASE_FIXTURE(Fixture, "table_insert_with_a_singleton_argument") TEST_CASE_FIXTURE(Fixture, "functions_are_not_to_be_widened") { ScopedFastFlag sff[]{ - {"LuauParseSingletonTypes", true}, - {"LuauSingletonTypes", true}, {"LuauWidenIfSupertypeIsFree2", true}, }; @@ -595,7 +483,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -614,7 +501,6 @@ TEST_CASE_FIXTURE(Fixture, "indexing_on_union_of_string_singletons") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -633,7 +519,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( @@ -652,7 +537,6 @@ TEST_CASE_FIXTURE(Fixture, "taking_the_length_of_union_of_string_singleton") { ScopedFastFlag sff[]{ {"LuauDiscriminableUnions2", true}, - {"LuauSingletonTypes", true}, }; CheckResult result = check(R"( diff --git a/tests/TypeInfer.tables.test.cpp b/tests/TypeInfer.tables.test.cpp index 91140aaa4..0cc12d199 100644 --- a/tests/TypeInfer.tables.test.cpp +++ b/tests/TypeInfer.tables.test.cpp @@ -2078,6 +2078,44 @@ caused by: Property '__call' is not compatible. Type '(a, b) -> ()' could not be converted into '(a) -> ()'; different number of generic type parameters)"); } +TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_key") +{ + ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; + + CheckResult result = check(R"( + type A = { [number]: string } + type B = { [string]: string } + + local a: A = { 'a', 'b' } + local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer key]' is not compatible. Type 'number' could not be converted into 'string')"); +} + +TEST_CASE_FIXTURE(Fixture, "error_detailed_indexer_value") +{ + ScopedFastFlag luauTableSubtypingVariance2{"LuauTableSubtypingVariance2", true}; // Only for new path + ScopedFastFlag luauExtendedIndexerError{"LuauExtendedIndexerError", true}; + + CheckResult result = check(R"( + type A = { [number]: number } + type B = { [number]: string } + + local a: A = { 1, 2, 3 } + local b: B = a + )"); + + LUAU_REQUIRE_ERRORS(result); + CHECK_EQ(toString(result.errors[0]), R"(Type 'A' could not be converted into 'B' +caused by: + Property '[indexer value]' is not compatible. Type 'number' could not be converted into 'string')"); +} + TEST_CASE_FIXTURE(Fixture, "explicitly_typed_table") { ScopedFastFlag sffs[]{ diff --git a/tests/TypeInfer.test.cpp b/tests/TypeInfer.test.cpp index 571d0f8d6..660ddcfcf 100644 --- a/tests/TypeInfer.test.cpp +++ b/tests/TypeInfer.test.cpp @@ -334,7 +334,7 @@ TEST_CASE_FIXTURE(Fixture, "check_expr_recursion_limit") #if defined(LUAU_ENABLE_ASAN) int limit = 250; #elif defined(_DEBUG) || defined(_NOOPT) - int limit = 350; + int limit = 300; #else int limit = 600; #endif diff --git a/tests/TypeInfer.unionTypes.test.cpp b/tests/TypeInfer.unionTypes.test.cpp index ad1e31e5c..68b7c4fb7 100644 --- a/tests/TypeInfer.unionTypes.test.cpp +++ b/tests/TypeInfer.unionTypes.test.cpp @@ -296,6 +296,7 @@ return f() REQUIRE(acm); CHECK_EQ(1, acm->expected); CHECK_EQ(0, acm->actual); + CHECK_FALSE(acm->isVariadic); } TEST_CASE_FIXTURE(Fixture, "optional_field_access_error")