labyrinth

It's a-maze-ing how deep the rabbit hole goes.

Table of Contents

[ + ] Overview
[ + ] Reversing
[ + ] Solution

[ + ]Overview

Here is the challenge, a single prompt asking us for a simple answer:

Image

Let's take a look under the hood.

[ + ]Reversing

From a high level, it looks fairly straightforward. The function traverse, which takes our input s, must return 9595 for us to get the flag.

Image

Taking a look at the traverse subroutine, it's not immediately clear what is happening. It seems that while the value a1 is not 0, we loop, adding some values to v3, and then return.

Image

To get a better idea of what is happening, we need to look at the disassembly. Below is a screenshot of what is happening inside the loop. We can see on the left that there are a few conditional statements.

The first conditional statement checks if al is equal to L (0x4C). If so, it jumps to the blue box, loads the address stored at rax + 0x8, and stores it in var_8 (v4).

The second conditional statement checks if al is equal to R (0x52). If so, it jumps to the orange box, loads the address stored at rax + 0x10, and stores it in var_8 (v4).

Image
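
To make the control flow concrete, here is a rough Python model of what the loop appears to do. The tuple-based node table, the function signature, and the toy values are all made up for illustration. The order of operations (add the current node's value, then follow L or R) is an assumption drawn from the decompilation, not something confirmed.

```python
# Hypothetical model of traverse(); names and data are illustrative only.
# Each node is (value, left_index, right_index).
def traverse(nodes, moves, start=0):
    total = 0
    current = start
    for ch in moves:
        value, left, right = nodes[current]
        total += value          # assumed: the value is added before moving
        if ch == 'L':
            current = left      # pointer stored at node + 0x8
        elif ch == 'R':
            current = right     # pointer stored at node + 0x10
        # Any other character leaves the position unchanged in this model;
        # the real binary may behave differently.
    return total


toy_nodes = [(5, 1, 2), (7, 0, 2), (11, 1, 0)]
print(traverse(toy_nodes, "LR"))   # 5 + 7 = 12
```

In this model a node's value only counts once another character is consumed, so a path that should include its last node needs one extra move at the end.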

So what does this mean? Let's take a look at the data being used in the operation. First, let's dump it from the binary.

Image

Then open it up in 010 Editor. Now it is starting to make a little more sense. The first 0x8 bytes (yellow) hold the value added to the total, the bytes from 0x8 - 0xF (red) hold the address we jump to for the next iteration if L is specified, and the bytes from 0x10 - 0x17 (green) hold the address we jump to for the next iteration if R is specified.

Image
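
As a small sketch of that layout, the snippet below decodes one node record with Python's struct module. The helper name read_node and the synthetic sample bytes are illustrative assumptions. The field offsets are the ones read off the hex dump, and the sample numbers are the node 0 values quoted in this write-up.

```python
import struct

# Assumed layout of one node, as read off the hex dump:
#   +0x00  value added to the running total (8 bytes, little-endian)
#   +0x08  address of the next node for 'L'
#   +0x10  address of the next node for 'R'
NODE_FORMAT = '<QQQ'


def read_node(blob, offset=0):
    """Decode a single node record starting at `offset` in `blob`."""
    value, left_addr, right_addr = struct.unpack_from(NODE_FORMAT, blob, offset)
    return value, left_addr, right_addr


# Synthetic record standing in for the real dump.
sample = struct.pack(NODE_FORMAT, 0xE3, 0x6010B0, 0x601130)
value, left_addr, right_addr = read_node(sample)
print(hex(value), hex(left_addr), hex(right_addr))  # 0xe3 0x6010b0 0x601130
```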

So what now? Well, we can take the offset of each address from the base and use that as an index.

Since PIE is disabled, the base is always 0x601070, and each node is 0x20 bytes apart, so we can calculate everything.

0: 
    Add 0xE3
    L -> 2 ((0x6010B0 - 0x601070) / 0x20)
    R -> 6 ((0x601130 - 0x601070) / 0x20)

1: 
    ...
...
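
The index arithmetic above can be written as a tiny helper. The function name addr_to_index and the sanity checks are illustrative additions; the base address and the 0x20 stride come from the calculation shown for node 0.

```python
BASE = 0x601070     # address of node 0 (PIE disabled)
NODE_SIZE = 0x20    # distance between consecutive nodes


def addr_to_index(addr):
    """Convert a node address from the dump into its index in the table."""
    delta = addr - BASE
    if delta < 0 or delta % NODE_SIZE != 0:
        raise ValueError(f"{addr:#x} is not a node address")
    return delta // NODE_SIZE


print(addr_to_index(0x6010B0))  # 2
print(addr_to_index(0x601130))  # 6
```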

We could continue this manually, but we can automate it. The idea is that we have a graph, and the sum of the values of all the nodes we visit must equal 9595.

[ + ]Solution

First we need to parse the data we dumped:

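offset = 0x601070  # address of node 0 (PIE is disabled)
data = []
with open('hex_dump.dat', 'rb') as f:
    dump = f.read()
    count = 0
    # Walk the dump one little-endian dword at a time.
    while count < len(dump):
        tmp = [dump[count + 3], dump[count + 2], dump[count + 1], dump[count]]
        # Skip all-zero dwords (the upper halves of the 64-bit fields and the
        # padding), which leaves three entries per node: value, L address,
        # R address. This relies on no field in the dump being zero itself.
        if not all(x == 0 for x in tmp):
            data.append(int.from_bytes(bytearray(tmp), byteorder='big'))
        count += 4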

data_dict = {}

for pos, x in enumerate(data):
    current_index = pos // 3
    if not data_dict.get(current_index):
        data_dict[current_index] = {}
        data_dict[current_index]['path'] = {}
    item_type = pos - (current_index * 3)
    if item_type == 0:
        data_dict[current_index]['value'] = x
    elif item_type == 1:
        data_dict[current_index]['path'][1] = (x - offset) // 32
    else:
        data_dict[current_index]['path'][2] = (x - offset) // 32

value_mapping = {}
for k, v in data_dict.items():
    value_mapping[v['value']] = k

Now that we have data_dict in a graph-like format, we can do a recursive depth-first search:

def find_path(data, node, target, path=None):
    if path is None:
        path = []
    if target == data[node]['value']:
        return path
    if target < data[node]['value']:
        return None
    for choice, next_node in data[node]['path'].items():
        new_path = path + [choice]
        result = find_path(data, next_node, target - data[node]['value'], new_path)
        if result is not None:
            return result
    return None

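My algorithm was not perfect: find_path returns as soon as it reaches the node whose value completes the sum, so the path contains no move for that last node. Presumably traverse only adds a node's value when it consumes another input character, so after replacing the 1s and 2s I had to manually append the remaining move: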

# Excerpt from solver(), right after find_path() returns
if valid_path is not None:
    valid_path = ''.join([str(x) for x in valid_path]).replace('1', 'L').replace('2', 'R')
    # Alg is not perfect, need to append L to the end
    valid_path += 'L'
    return valid_path
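
One possible way to avoid the manual fix-up is to let the search emit the final move itself. The sketch below is a hypothetical variant, not the code used to solve the challenge. It assumes the graph has the same shape as data_dict, that every node value is positive (otherwise the search may not terminate on a loopy graph), and that any trailing move is enough for the last node's value to be counted. The name find_moves and the toy graph are made up for illustration.

```python
def find_moves(graph, start, target):
    """Iterative depth-first search returning a string of 'L'/'R' moves.

    graph: {index: {'value': int, 'path': {1: left_index, 2: right_index}}}
    """
    stack = [(start, target, "")]
    while stack:
        node, remaining, moves = stack.pop()
        value = graph[node]['value']
        if value > remaining:
            continue                      # overshoot, abandon this branch
        if value == remaining:
            # Assumed: one more move is needed so this node's value is added.
            return moves + 'L'
        # Push R first so that L is explored first, like the recursive version.
        for choice, letter in ((2, 'R'), (1, 'L')):
            nxt = graph[node]['path'].get(choice)
            if nxt in graph:
                stack.append((nxt, remaining - value, moves + letter))
    return None


toy_graph = {
    0: {'value': 5, 'path': {1: 1, 2: 2}},
    1: {'value': 7, 'path': {1: 0, 2: 2}},
    2: {'value': 11, 'path': {1: 1, 2: 0}},
}
print(find_moves(toy_graph, 0, 23))  # LRL
```

Using an explicit stack also sidesteps Python's recursion limit if the winning path turns out to be long.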

After all of that, we can solve the challenge.

Image
Solver
from pwn import *

# Context
context.arch = 'amd64'
context.log_level = 'DEBUG'


# Main vars
NETID = ''
HOST, PORT = 'host', 1253


def find_path(data, node, target, path=None):
    if path is None:
        path = []
    if target == data[node]['value']:
        return path
    if target < data[node]['value']:
        return None

    for choice, next_node in data[node]['path'].items():
        new_path = path + [choice]
        result = find_path(data, next_node, target - data[node]['value'], new_path)
        if result is not None:
            return result

    return None


def solver():
    offset = 0x601070
    data = []
    with open('hex_dump.dat', 'rb') as f:
        dump = f.read()
        count = 0
        while count < len(dump):
            tmp = [dump[count + 3], dump[count + 2], dump[count + 1], dump[count]]
            if not all(x == 0 for x in tmp):
                data.append(int.from_bytes(bytearray(tmp), byteorder='big'))
            count += 4

    data_dict = {}

    for pos, x in enumerate(data):
        current_index = pos // 3
        if not data_dict.get(current_index):
            data_dict[current_index] = {}
            data_dict[current_index]['path'] = {}
        item_type = pos - (current_index * 3)
        if item_type == 0:
            data_dict[current_index]['value'] = x
        elif item_type == 1:
            data_dict[current_index]['path'][1] = (x - offset) // 32
        else:
            data_dict[current_index]['path'][2] = (x - offset) // 32

    value_mapping = {}
    for k, v in data_dict.items():
        value_mapping[v['value']] = k

    target = 9595
    valid_path = find_path(data_dict, 0, target)
    if valid_path is not None:
        valid_path = ''.join([str(x) for x in valid_path]).replace('1', 'L').replace('2', 'R')
        # Alg is not perfect, need to append L to the end
        valid_path += 'L'
        return valid_path
    return None


def main():
    conn = remote(HOST, PORT)
    conn.recvuntil(b'(something like abc123): ')
    conn.sendline(NETID.encode())  # pwntools expects bytes on Python 3
    conn.recvuntil(b'You\'re trapped in a windy, loopy maze. All the walls look the same. Can you find your way through?\n')
    solution = solver()
    if solution is not None:
        conn.sendline(solution.encode())  # send the L/R path as bytes
        conn.recvuntil(b'flag{')
        response = conn.recvline()
        conn.close()
        print("flag{" + response.decode().strip())


if __name__ == "__main__":
    main()