diff --git a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs index 7c414076b..5955334b8 100644 --- a/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs +++ b/accelerate-llvm-ptx/src/Data/Array/Accelerate/LLVM/PTX/CodeGen/Base.hs @@ -445,7 +445,7 @@ shfl sop tR val delta = go tR val repack :: Int32 -> CodeGen PTX (Operands (Vec m Int32)) repack 0 = return $ ir v' (A.undef (VectorScalarType v')) repack i = do - d <- instr $ ExtractElement (i-1) c + d <- instr $ ExtractElement integralType c (constOp (i-1)) e <- integral integralType d f <- repack (i-1) g <- instr $ InsertElement (i-1) (op v' f) (op integralType e) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs index 48555c94a..47ab72c60 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Arithmetic.hs @@ -20,6 +20,8 @@ module Data.Array.Accelerate.LLVM.CodeGen.Arithmetic where +import Data.Primitive.Vec + import Data.Array.Accelerate.AST ( PrimMaybe ) import Data.Array.Accelerate.Analysis.Match import Data.Array.Accelerate.Representation.Tag @@ -464,6 +466,17 @@ min ty x y | otherwise = do c <- unbool <$> lte ty x y binop (flip Select c) ty x y +-- Vector operators +-- ---------------------- + +vecCreate :: VectorType (Vec n a) -> CodeGen arch (Operands (Vec n a)) +vecCreate = undefined + +vecIndex :: VectorType (Vec n a) -> IntegralType i -> Operands (Vec n a) -> Operands i -> CodeGen arch (Operands a) +vecIndex tv ti (OP_Vec v) i = do + (OP_Int32 i') <- fromIntegral ti (IntegralNumType TypeInt32) i + instr $ ExtractElement TypeInt32 v i' + -- Logical operators -- ----------------- diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs index ea984d21c..1f9ca1f82 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Array.hs @@ -28,6 +28,7 @@ import LLVM.AST.Type.AddrSpace import LLVM.AST.Type.Instruction import LLVM.AST.Type.Instruction.Volatile import LLVM.AST.Type.Operand +import LLVM.AST.Type.Constant import LLVM.AST.Type.Representation import Data.Array.Accelerate.Representation.Array @@ -205,16 +206,15 @@ store addrspace volatility e p v | SingleScalarType{} <- e = do_ $ Store volatility p v | VectorScalarType s <- e , VectorType n base <- s - , m <- fromIntegral n - = if popCount m == 1 + = if popCount n == 1 then do_ $ Store volatility p v else do p' <- instr' $ PtrCast (PtrPrimType (ScalarPrimType (SingleScalarType base)) addrspace) p -- - let go i - | i >= m = return () + let go i + | i >= n = return () | otherwise = do - x <- instr' $ ExtractElement i v + x <- instr' $ ExtractElement integralType v (ConstantOperand (ScalarConstant scalarType n)) q <- instr' $ GetElementPtr p' [integral integralType i] _ <- instr' $ Store volatility q x go (i+1) diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs index 26c9497ed..962f1d86f 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Constant.hs @@ -61,7 +61,7 @@ scalar t = ConstantOperand . ScalarConstant t single :: SingleType a -> a -> Operand a single t = scalar (SingleScalarType t) -vector :: VectorType (Vec n a) -> (Vec n a) -> Operand (Vec n a) +vector :: VectorType (Vec n a) -> Vec n a -> Operand (Vec n a) vector t = scalar (VectorScalarType t) num :: NumType a -> a -> Operand a diff --git a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs index 77a2f0860..95149b14b 100644 --- a/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs +++ b/accelerate-llvm/src/Data/Array/Accelerate/LLVM/CodeGen/Exp.hs @@ -48,7 +48,8 @@ import qualified Data.Array.Accelerate.LLVM.CodeGen.Loop as L import Data.Primitive.Vec import LLVM.AST.Type.Instruction -import LLVM.AST.Type.Operand ( Operand ) +import LLVM.AST.Type.Operand ( Operand(..), constOp) +import LLVM.AST.Type.Constant ( Constant(..), ) import Control.Applicative hiding ( Const ) import Control.Monad @@ -105,7 +106,7 @@ llvmOfOpenExp top env aenv = cvtE top llvmOfOpenExp body (env `pushE` (lhs, x)) aenv Evar (Var _ ix) -> return $ prj ix env Const tp c -> return $ ir tp $ scalar tp c - PrimConst c -> let tp = (SingleScalarType $ primConstType c) + PrimConst c -> let tp = primConstType c in return $ ir tp $ scalar tp $ primConst c PrimApp f x -> primFun f x Undef tp -> return $ ir tp $ undef tp @@ -165,7 +166,7 @@ llvmOfOpenExp top env aenv = cvtE top go (VecRnil _) _ = internalError "index mismatch" go (VecRsucc vecr') i = do xs <- go vecr' (i - 1) - x <- instr' $ ExtractElement (fromIntegral i - 1) vec + x <- instr' $ ExtractElement TypeInt vec (constOp (i - 1)) return $ OP_Pair xs (ir singleTp x) singleTp :: SingleType single -- GHC 8.4 cannot infer this type for some reason @@ -307,6 +308,7 @@ llvmOfOpenExp top env aenv = cvtE top PrimEq t -> primbool $ A.uncurry (A.eq t) =<< cvtE x PrimNEq t -> primbool $ A.uncurry (A.neq t) =<< cvtE x PrimLNot -> primbool $ A.lnot =<< bool (cvtE x) + PrimVectorIndex v i -> A.uncurry (A.vecIndex v i) =<< cvtE x -- no missing patterns, whoo! diff --git a/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs b/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs index 4566ab1d2..b54df8486 100644 --- a/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs +++ b/accelerate-llvm/src/LLVM/AST/Type/Instruction.hs @@ -182,8 +182,9 @@ data Instruction a where -- -- - ExtractElement :: Int32 -- TupleIdx (ProdRepr (Vec n a)) a + ExtractElement :: IntegralType i -- TupleIdx (ProdRepr (Vec n a)) a -> Operand (Vec n a) + -> Operand i -> Instruction a -- @@ -406,7 +407,7 @@ instance Downcast (Instruction a) LLVM.Instruction where BXor _ x y -> LLVM.Xor (downcast x) (downcast y) md LNot x -> LLVM.Xor (downcast x) (LLVM.ConstantOperand (LLVM.Int 1 1)) md InsertElement i v x -> LLVM.InsertElement (downcast v) (downcast x) (constant i) md - ExtractElement i v -> LLVM.ExtractElement (downcast v) (constant i) md + ExtractElement _ v i -> LLVM.ExtractElement (downcast v) (downcast i) md ExtractValue _ i s -> extractStruct i (downcast s) Load _ v p -> LLVM.Load (downcast v) (downcast p) atomicity alignment md Store v p x -> LLVM.Store (downcast v) (downcast p) (downcast x) atomicity alignment md @@ -594,7 +595,7 @@ instance TypeOf Instruction where LAnd x _ -> typeOf x LOr x _ -> typeOf x LNot x -> typeOf x - ExtractElement _ x -> typeOfVec x + ExtractElement _ x _ -> typeOfVec x InsertElement _ x _ -> typeOf x ExtractValue t _ _ -> scalar t Load t _ _ -> scalar t diff --git a/accelerate-llvm/src/LLVM/AST/Type/Operand.hs b/accelerate-llvm/src/LLVM/AST/Type/Operand.hs index 9bd3ab21e..f2e73b8c0 100644 --- a/accelerate-llvm/src/LLVM/AST/Type/Operand.hs +++ b/accelerate-llvm/src/LLVM/AST/Type/Operand.hs @@ -15,6 +15,7 @@ module LLVM.AST.Type.Operand ( Operand(..), + constOp, ) where @@ -32,6 +33,9 @@ data Operand a where LocalReference :: Type a -> Name a -> Operand a ConstantOperand :: Constant a -> Operand a +constOp :: (IsScalar a) => a -> Operand a +constOp x = ConstantOperand (ScalarConstant scalarType x) + -- | Convert to llvm-hs --