Переглянути джерело

fix some problems with calling functions from assembly and improve instruction shortcuts

Kolja Strohm 3 місяців тому
батько
коміт
da0fa6f68f
3 змінених файлів з 386 додано та 49 видалено
  1. 165 14
      Assembly.cpp
  2. 145 8
      Assembly.h
  3. 76 27
      Framework Tests/Assembly.cpp

+ 165 - 14
Assembly.cpp

@@ -1187,7 +1187,14 @@ public:
         Framework::Assembly::OperationArgument* arg,
         int index) const
     {
-        result.immLength = 1;
+        if (result.immLength >= 8)
+        {
+            Framework::Text* err = new Framework::Text();
+            err->append() << "Invalid argument type for operand " << index
+                          << " for operation " << op
+                          << " encoded as IMM8: imm bytes are already in use";
+            throw err->getText();
+        }
         const Framework::Assembly::ConstantArgument* constArg
             = arg->asConstantArgument();
         if (constArg == 0)
@@ -1210,15 +1217,22 @@ public:
                           << " but expected size BYTE";
             throw err->getText();
         }
-        result.imm[0] = (char)(value);
-        result.immLength = 1;
+        result.imm[(int)result.immLength] = (char)(value);
+        result.immLength += 1;
     }
 
     void encodeIMM16(MachineCodeInstruction& result,
         Framework::Assembly::OperationArgument* arg,
         int index) const
     {
-        result.immLength = 1;
+        if (result.immLength >= 7)
+        {
+            Framework::Text* err = new Framework::Text();
+            err->append() << "Invalid argument type for operand " << index
+                          << " for operation " << op
+                          << " encoded as IMM8: imm bytes are already in use";
+            throw err->getText();
+        }
         const Framework::Assembly::ConstantArgument* constArg
             = arg->asConstantArgument();
         if (constArg == 0)
@@ -1242,15 +1256,22 @@ public:
             throw err->getText();
         }
         short val = (short)(value);
-        memcpy(result.imm, &val, 2);
-        result.immLength = 2;
+        memcpy(result.imm + result.immLength, &val, 2);
+        result.immLength += 2;
     }
 
     void encodeIMM32(MachineCodeInstruction& result,
         Framework::Assembly::OperationArgument* arg,
         int index) const
     {
-        result.immLength = 1;
+        if (result.immLength >= 5)
+        {
+            Framework::Text* err = new Framework::Text();
+            err->append() << "Invalid argument type for operand " << index
+                          << " for operation " << op
+                          << " encoded as IMM8: imm bytes are already in use";
+            throw err->getText();
+        }
         const Framework::Assembly::ConstantArgument* constArg
             = arg->asConstantArgument();
         if (constArg == 0)
@@ -1273,15 +1294,22 @@ public:
                           << " but expected size range [BYTE, DWORD]";
             throw err->getText();
         }
-        memcpy(result.imm, &value, 4);
-        result.immLength = 4;
+        memcpy(result.imm + result.immLength, &value, 4);
+        result.immLength += 4;
     }
 
     void encodeIMM64(MachineCodeInstruction& result,
         Framework::Assembly::OperationArgument* arg,
         int index) const
     {
-        result.immLength = 1;
+        if (result.immLength >= 1)
+        {
+            Framework::Text* err = new Framework::Text();
+            err->append() << "Invalid argument type for operand " << index
+                          << " for operation " << op
+                          << " encoded as IMM8: imm bytes are already in use";
+            throw err->getText();
+        }
         const Framework::Assembly::ConstantArgument* constArg
             = arg->asConstantArgument();
         if (constArg == 0)
@@ -1304,8 +1332,8 @@ public:
                           << " but expected size range [BYTE, QWORD]";
             throw err->getText();
         }
-        memcpy(result.imm, &value, 8);
-        result.immLength = 8;
+        memcpy(result.imm + result.immLength, &value, 8);
+        result.immLength += 8;
     }
 };
 
@@ -6296,6 +6324,28 @@ void __intializeMachineCodeTranslationTable()
                             Framework::Assembly::MemoryBlockSize::QWORD),
                         MODRM_RM,
                         READ)}));
+        OperationCodeTable::machineCodeTranslationTable.add(
+            new OperationCodeTable(Framework::Assembly::ENTER,
+                {// ENTER
+                    MachineCodeTableEntry(false,
+                        0xC8,
+                        (char)1,
+                        false,
+                        false,
+                        false,
+                        0,
+                        0,
+                        isIMM(Framework::Assembly::MemoryBlockSize::WORD),
+                        IMM16,
+                        READ,
+                        isIMM(Framework::Assembly::MemoryBlockSize::BYTE),
+                        IMM8,
+                        READ)}));
+        OperationCodeTable::machineCodeTranslationTable.add(
+            new OperationCodeTable(Framework::Assembly::LEAVE,
+                {// LEAVE
+                    MachineCodeTableEntry(
+                        false, 0xC9, (char)1, false, false, false, 0, 0)}));
         OperationCodeTable::machineCodeTranslationTable.add(
             new OperationCodeTable(Framework::Assembly::RET,
                 {// RET
@@ -7054,14 +7104,115 @@ void Framework::Assembly::AssemblyBlock::addLoadValue(
             new MemoryAccessArgument(MemoryBlockSize::QWORD, temp)}));
 }
 
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    GPRegister target, char value)
+{
+    instructions.add(new Instruction(MOV,
+        {new GPRegisterArgument(target, LOWER8), new ConstantArgument(value)}));
+}
+
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    GPRegister target, short value)
+{
+    instructions.add(new Instruction(MOV,
+        {new GPRegisterArgument(target, LOWER16),
+            new ConstantArgument(value)}));
+}
+
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    GPRegister target, int value)
+{
+    instructions.add(new Instruction(MOV,
+        {new GPRegisterArgument(target, LOWER32),
+            new ConstantArgument(value)}));
+}
+
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    GPRegister target, __int64 value)
+{
+    instructions.add(new Instruction(
+        MOV, {new GPRegisterArgument(target), new ConstantArgument(value)}));
+}
+
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    FPRegister target, float value, GPRegister temp)
+{
+    int data = *reinterpret_cast<int*>(&value);
+    addMoveValue(temp, data);
+    addPush(temp, LOWER32);
+    instructions.add(new Instruction(MOVSS,
+        {new FPRegisterArgument(target, X),
+            new MemoryAccessArgument(
+                MemoryBlockSize::DWORD, RSP, true, -4, true)}));
+    addPop(temp, LOWER32);
+}
+
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    FPRegister target, double value, GPRegister temp)
+{
+    __int64 data = *reinterpret_cast<__int64*>(&value);
+    addMoveValue(temp, data);
+    addPush(temp);
+    instructions.add(new Instruction(MOVSD,
+        {new FPRegisterArgument(target, X),
+            new MemoryAccessArgument(
+                MemoryBlockSize::QWORD, RSP, true, -8, true)}));
+    addPop(temp);
+}
+
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    GPRegister target, GPRegister source, GPRegisterPart part)
+{
+    instructions.add(new Instruction(MOV,
+        {new GPRegisterArgument(target, part),
+            new GPRegisterArgument(source, part)}));
+}
+
+void Framework::Assembly::AssemblyBlock::addMoveValue(
+    FPRegister target, FPRegister source, FPDataType type, FPRegisterPart part)
+{
+    Operation op = NOP;
+    switch (type)
+    {
+    case SINGLE_FLOAT:
+        op = MOVSS;
+        break;
+    case SINGLE_DOUBLE:
+        op = MOVSD;
+        break;
+    case PACKED_FLOAT:
+        op = MOVAPS;
+        break;
+    case PACKED_DOUBLE:
+        op = MOVAPD;
+        break;
+    }
+    instructions.add(new Instruction(op,
+        {new FPRegisterArgument(target, part),
+            new FPRegisterArgument(source, part)}));
+}
+
 void Framework::Assembly::AssemblyBlock::addCall(
     void* functionAddress, GPRegister temp)
 {
     instructions.add(new Instruction(MOV,
         {new GPRegisterArgument(temp),
             new ConstantArgument(reinterpret_cast<__int64>(functionAddress))}));
-    instructions.add(new Instruction(
-        CALL, {new MemoryAccessArgument(MemoryBlockSize::QWORD, temp)}));
+    instructions.add(new Instruction(CALL, {new GPRegisterArgument(temp)}));
+}
+
+void Framework::Assembly::AssemblyBlock::addEnter(
+    short stackSize, char nestingLevel)
+{
+    instructions.add(
+        new Framework::Assembly::Instruction(Framework::Assembly::ENTER,
+            {new Framework::Assembly::ConstantArgument(stackSize),
+                new Framework::Assembly::ConstantArgument(nestingLevel)}));
+}
+
+void Framework::Assembly::AssemblyBlock::addLeave()
+{
+    instructions.add(new Instruction(LEAVE, {}));
 }
 
 void Framework::Assembly::AssemblyBlock::addReturn()

+ 145 - 8
Assembly.h

@@ -113,14 +113,16 @@ namespace Framework
             // Jump if below or equal: SpecialRegister.FLAGS(CF) = 1 or
             // SpecialRegister.FLAGS(ZF) = 1
             JBE,
-            JO,   // Jump if overflow: SpecialRegister.FLAGS(OF) = 1
-            JNO,  // Jump if not overflow: SpecialRegister.FLAGS(OF) = 0
-            JP,   // Jump if parity even: SpecialRegister.FLAGS(PF) = 1
-            JNP,  // Jump if not parity odd: SpecialRegister.FLAGS(PF) = 0
-            JS,   // Jump if sign: SpecialRegister.FLAGS(SF) = 1
-            JNS,  // Jump if not sign: SpecialRegister.FLAGS(SF) = 0
-            CALL, // Call subroutine
-            RET,  // Return from subroutine
+            JO,    // Jump if overflow: SpecialRegister.FLAGS(OF) = 1
+            JNO,   // Jump if not overflow: SpecialRegister.FLAGS(OF) = 0
+            JP,    // Jump if parity even: SpecialRegister.FLAGS(PF) = 1
+            JNP,   // Jump if not parity odd: SpecialRegister.FLAGS(PF) = 0
+            JS,    // Jump if sign: SpecialRegister.FLAGS(SF) = 1
+            JNS,   // Jump if not sign: SpecialRegister.FLAGS(SF) = 0
+            CALL,  // Call subroutine
+            ENTER, // Set up stack frame for procedure
+            LEAVE, // Destroy stack frame for procedure
+            RET,   // Return from subroutine
 
             // Stack Operations
 
@@ -164,6 +166,14 @@ namespace Framework
             TRUE_S = 31    // True (signaling)
         };
 
+        enum FPDataType
+        {
+            SINGLE_FLOAT,
+            SINGLE_DOUBLE,
+            PACKED_FLOAT,
+            PACKED_DOUBLE
+        };
+
         // General Purpose Registers
         enum GPRegister
         {
@@ -179,6 +189,7 @@ namespace Framework
             // pushed to the stack. The stack grows downwards so lower means
             // more elements in the stack. needs to be aligned to 16 bytes
             RSP = 0b0100,
+            // base pointer points to the base of the current stack frame
             // non-volatile register (mus be restored on return)
             RBP = 0b0101,
             // non-volatile register (mus be restored on return)
@@ -675,6 +686,96 @@ namespace Framework
              */
             DLLEXPORT void addLoadValue(
                 double* valueAddress, FPRegister target, GPRegister temp = RAX);
+
+            /**
+             * calls a function at a specified memory address.
+             *
+             * \param functionAddress pointet to the address of the function to
+             * call
+             */
+            template<typename T>
+            void addLoadAddress(T* addr, GPRegister temp = RAX)
+            {
+                instructions.add(new Instruction(MOV,
+                    {new GPRegisterArgument(temp),
+                        new ConstantArgument(
+                            reinterpret_cast<__int64>(addr))}));
+            }
+
+            /**
+             * moves the given value to LOWER8 bits of target register.
+             *
+             * \param target the register to move the value to
+             * \param value the value to move
+             */
+            DLLEXPORT void addMoveValue(GPRegister target, char value);
+            /**
+             * moves the given value to LOWER16 bits of target register.
+             *
+             * \param target the register to move the value to
+             * \param value the value to move
+             */
+            DLLEXPORT void addMoveValue(GPRegister target, short value);
+            /**
+             * moves the given value to LOWER32 bits of target register.
+             *
+             * \param target the register to move the value to
+             * \param value the value to move
+             */
+            DLLEXPORT void addMoveValue(GPRegister target, int value);
+            /**
+             * moves the given value to LOWER64 bits of target register.
+             *
+             * \param target the register to move the value to
+             * \param value the value to move
+             */
+            DLLEXPORT void addMoveValue(GPRegister target, __int64 value);
+            /**
+             * moves the given value to LOWER32 bits of target register.
+             * The value needs to be temporarily pushed to the stack, so
+             * addLoadValue should be used if performance is critical.
+             *
+             * \param target the register to move the value to
+             * \param value the value to move
+             * \param temp temporary register that can be used us store the
+             * value
+             */
+            DLLEXPORT void addMoveValue(
+                FPRegister target, float value, GPRegister temp = RAX);
+            /**
+             * moves the given value to LOWER64 bits of target register.
+             * The value needs to be temporarily pushed to the stack, so
+             * addLoadValue should be used if performance is critical.
+             *
+             * \param target the register to move the value to
+             * \param value the value to move
+             * \param temp temporary register that can be used us store the
+             * value
+             */
+            DLLEXPORT void addMoveValue(
+                FPRegister target, double value, GPRegister temp = RAX);
+            /**
+             * moves the value from source register to target register.
+             *
+             * \param target the register to move the value to
+             * \param source the register to move the value from
+             * \param part the part of the register to move
+             */
+            DLLEXPORT void addMoveValue(GPRegister target,
+                GPRegister source,
+                GPRegisterPart part = FULL64);
+            /**
+             * moves the value from source register to target register.
+             *
+             * \param target the register to move the value to
+             * \param source the register to move the value from
+             * \param part the part of the register to move
+             */
+            DLLEXPORT void addMoveValue(FPRegister target,
+                FPRegister source,
+                FPDataType type,
+                FPRegisterPart part = Y);
+
             /**
              * calls a function at a specified memory address.
              *
@@ -683,6 +784,33 @@ namespace Framework
              */
             DLLEXPORT void addCall(
                 void* functionAddress, GPRegister temp = RAX);
+
+            /**
+             * calls a function at a specified memory address.
+             *
+             * \param functionAddress pointet to the address of the function to
+             * call
+             */
+            template<typename T>
+            void addMemberCall(T&& functionAddress, GPRegister temp = RAX)
+            {
+                addCall((void*&)functionAddress, temp);
+            }
+
+            /**
+             * adds an ENTER instruction to set up a stack frame for a function
+             *
+             * \param stackSize the size of the stack frame to create. should be
+             * a multiple of 16 and >= 32
+             * \param nestingLevel the nesting level of the function. usually 0
+             */
+            DLLEXPORT void addEnter(
+                short stackSize = 32, char nestingLevel = 0);
+            /**
+             * adds a LEAVE instruction to destroy the stack frame of a
+             * function.
+             */
+            DLLEXPORT void addLeave();
             /**
              * returns from executing the compiled assembly function.
              */
@@ -814,6 +942,15 @@ namespace Framework
              * reinterpret_cast<returnType(*)(parameterTypes...)>(compile())(parameters...)
              */
             DLLEXPORT void* compile();
+
+            /**
+             * \return a pointer to a function that contains the compiled byte
+             * code of this assembly block. and can be called directly with
+             */
+            template<typename T> T compileToFunction()
+            {
+                return reinterpret_cast<T>(compile());
+            }
         };
     } // namespace Assembly
 } // namespace Framework

+ 76 - 27
Framework Tests/Assembly.cpp

@@ -5,6 +5,29 @@
 
 using namespace Microsoft::VisualStudio::CppUnitTestFramework;
 
+int globalFunc(int a, int b)
+{
+    return a * b;
+}
+
+class A
+{
+public:
+    virtual int getValue()
+    {
+        return 0;
+    }
+};
+
+class B : public A
+{
+public:
+    virtual int getValue() override
+    {
+        return 10;
+    }
+};
+
 namespace FrameworkTests
 {
     TEST_CLASS (AssemblyTests)
@@ -13,13 +36,9 @@ namespace FrameworkTests
         TEST_METHOD (Add8Test)
         {
             Framework::Assembly::AssemblyBlock codeBlock;
-            codeBlock.addInstruction(
-                new Framework::Assembly::Instruction(Framework::Assembly::MOV,
-                    {new Framework::Assembly::GPRegisterArgument(
-                         Framework::Assembly::RAX, Framework::Assembly::LOWER8),
-                        new Framework::Assembly::GPRegisterArgument(
-                            Framework::Assembly::RDX,
-                            Framework::Assembly::LOWER8)}));
+            codeBlock.addMoveValue(Framework::Assembly::RAX,
+                Framework::Assembly::RDX,
+                Framework::Assembly::LOWER8);
             codeBlock.addInstruction(
                 new Framework::Assembly::Instruction(Framework::Assembly::ADD,
                     {new Framework::Assembly::GPRegisterArgument(
@@ -38,13 +57,9 @@ namespace FrameworkTests
         TEST_METHOD (Add16Test)
         {
             Framework::Assembly::AssemblyBlock codeBlock;
-            codeBlock.addInstruction(new Framework::Assembly::Instruction(
-                Framework::Assembly::MOV,
-                {new Framework::Assembly::GPRegisterArgument(
-                     Framework::Assembly::RAX, Framework::Assembly::LOWER16),
-                    new Framework::Assembly::GPRegisterArgument(
-                        Framework::Assembly::RDX,
-                        Framework::Assembly::LOWER16)}));
+            codeBlock.addMoveValue(Framework::Assembly::RAX,
+                Framework::Assembly::RDX,
+                Framework::Assembly::LOWER16);
             codeBlock.addInstruction(new Framework::Assembly::Instruction(
                 Framework::Assembly::ADD,
                 {new Framework::Assembly::GPRegisterArgument(
@@ -63,13 +78,9 @@ namespace FrameworkTests
         TEST_METHOD (Add32Test)
         {
             Framework::Assembly::AssemblyBlock codeBlock;
-            codeBlock.addInstruction(new Framework::Assembly::Instruction(
-                Framework::Assembly::MOV,
-                {new Framework::Assembly::GPRegisterArgument(
-                     Framework::Assembly::RAX, Framework::Assembly::LOWER32),
-                    new Framework::Assembly::GPRegisterArgument(
-                        Framework::Assembly::RDX,
-                        Framework::Assembly::LOWER32)}));
+            codeBlock.addMoveValue(Framework::Assembly::RAX,
+                Framework::Assembly::RDX,
+                Framework::Assembly::LOWER32);
             codeBlock.addInstruction(new Framework::Assembly::Instruction(
                 Framework::Assembly::ADD,
                 {new Framework::Assembly::GPRegisterArgument(
@@ -87,12 +98,8 @@ namespace FrameworkTests
         TEST_METHOD (Add64Test)
         {
             Framework::Assembly::AssemblyBlock codeBlock;
-            codeBlock.addInstruction(
-                new Framework::Assembly::Instruction(Framework::Assembly::MOV,
-                    {new Framework::Assembly::GPRegisterArgument(
-                         Framework::Assembly::RAX),
-                        new Framework::Assembly::GPRegisterArgument(
-                            Framework::Assembly::RDX)}));
+            codeBlock.addMoveValue(
+                Framework::Assembly::RAX, Framework::Assembly::RDX);
             codeBlock.addInstruction(
                 new Framework::Assembly::Instruction(Framework::Assembly::ADD,
                     {new Framework::Assembly::GPRegisterArgument(
@@ -202,5 +209,47 @@ namespace FrameworkTests
             dresult = getd();
             Assert::AreEqual((double)66.0, dresult);
         }
+        int c;
+
+        int testMethod(int a, int b)
+        {
+            return a + b + c;
+        }
+
+        TEST_METHOD (testCall)
+        {
+            std::cout << 1.463f << 4.235;
+            Framework::Assembly::AssemblyBlock gfcodeBlock;
+            gfcodeBlock.addEnter();
+            gfcodeBlock.addMoveValue(Framework::Assembly::RCX, 20);
+            gfcodeBlock.addMoveValue(Framework::Assembly::RDX, 50);
+            gfcodeBlock.addCall(globalFunc);
+            gfcodeBlock.addLeave();
+            int (*f)() = gfcodeBlock.compileToFunction<int (*)()>();
+            int result = f();
+            Assert::AreEqual(globalFunc(20, 50), result);
+            Framework::Assembly::AssemblyBlock lfcodeBlock;
+            lfcodeBlock.addEnter();
+            lfcodeBlock.addLoadAddress(this, Framework::Assembly::RCX);
+            lfcodeBlock.addMoveValue(Framework::Assembly::RCX, 20);
+            lfcodeBlock.addMoveValue(Framework::Assembly::RDX, 50);
+            lfcodeBlock.addMemberCall(&AssemblyTests::testMethod);
+            lfcodeBlock.addLeave();
+            c = 1;
+            f = lfcodeBlock.compileToFunction<int (*)()>();
+            result = f();
+            A* a = new B();
+            Assert::AreEqual(testMethod(20, 50), result);
+            Framework::Assembly::AssemblyBlock vfcodeBlock;
+            vfcodeBlock.addEnter();
+            vfcodeBlock.addLoadAddress(a, Framework::Assembly::RCX);
+            vfcodeBlock.addMemberCall(&A::getValue);
+            vfcodeBlock.addLeave();
+            c = 1;
+            f = vfcodeBlock.compileToFunction<int (*)()>();
+            result = f();
+            Assert::AreEqual(a->getValue(), result);
+            delete a;
+        }
     };
 } // namespace FrameworkTests