#!/usr/bin/env python3

"""
This program inspects a Linux operating system and generates C source
code capable of printing human-readable descriptions of system calls
when given raw information from registers.  The script gathers the
mapping between system-call identifiers and system-call names through
the use of ausyscall. It also attempts to pull several headers from the
Linux kernel source tree of the running version of the kernel in order
to gather information about the arguments to each system call.
"""

import subprocess
import urllib.request
import re
import time
from optparse import OptionParser

# Per-syscall overrides for how each argument is printed.  The key is the
# syscall name; the value holds one (C cast type, printf format) pair per
# argument position.  A pair listed here replaces the pair that syscall_data
# would otherwise derive from the argument's declared type.
ARG_TYPE_OVERRIDES = {
    "sys_open": ( ( "char *", "%s" ),  ( "int", "%i" ), ( "unsigned long", "%lu" ) )
}

class syscall_data:
    """This class stores the information of a single system call. The data
    includes the syscall name, identifier, and argument information.

    The four argument lists are kept in lockstep: index N of each list
    describes the same argument.
    """

    def __init__(self, syscall_id, syscall_name):
        """The __init__ method is called to create the object.

        Inputs:
            syscall_id   -- integer
            syscall_name -- string

        No Output
        """

        self.__number           = syscall_id
        self.__name             = syscall_name
        self.__arg_names        = []
        self.__arg_types        = []
        self.__arg_printf_types = [] # Basic for esoteric (e.g., int for pid_t)
        self.__arg_formats      = []

    def __repr__(self):
        """The __repr__ function creates a string representation of the object.

        Inputs:
            None

        Outputs:
            string of the form "name(type1 name1, type2 name2);"
        """
        arguments = ", ".join(
            "{0} {1}".format(arg_type, arg_name)
            for arg_type, arg_name in zip(self.__arg_types, self.__arg_names)
        )
        return "{0}({1});".format(self.__name, arguments)

    def get_syscall_name(self):
        """Return the syscall name"""
        return self.__name

    def get_args_count(self):
        """Return the number of arguments"""
        return len(self.__arg_types)

    def get_arg_type(self, arg_num):
        """Return the argument type of the argument at index arg_num"""
        return self.__arg_types[arg_num]

    def get_arg_name(self, arg_num):
        """Return the name of the argument at index arg_num"""
        return self.__arg_names[arg_num]

    def get_arg_printf_type(self, arg_num):
        """Return the printf type of the argument at index arg_num"""
        return self.__arg_printf_types[arg_num]

    def get_arg_format(self, arg_num):
        """Return the format string for the argument at index arg_num"""
        return self.__arg_formats[arg_num]

    def add_argument(self, arg_type, arg_name = ""):
        """Add an argument to the syscall. This method will append a value
        to each of the argument information lists

        Input:
            arg_type -- string
            arg_name -- string -- default = ""; a generic "argN" name
                        (N counted from 1) is used when it is empty

        No Output
        """
        position = len(self.__arg_types)    # Zero-based index of the new argument.
        self.__arg_types.append(arg_type)

        if arg_name == "": # Use generic name if none provided.
            arg_name = "arg{0}".format(position + 1)
        self.__arg_names.append(arg_name)

        printf_type, printf_format = self.__get_printf_format_type(arg_type)

        # An override wins over the derived pair, but only for positions the
        # override table actually covers; extra arguments keep the derived
        # pair instead of raising IndexError.
        overrides = ARG_TYPE_OVERRIDES.get(self.__name, ())
        if position < len(overrides):
            printf_type, printf_format = overrides[position]

        self.__arg_printf_types.append(printf_type)
        self.__arg_formats.append(printf_format)

    def __get_printf_format_type(self, arg_type):
        """Return the (C cast type, printf format) pair for a declared type.

        Matching is by substring, so qualifiers such as "const" do not get
        in the way.  The groups are tried in a fixed order: pointers, then
        unsigned types, then int-sized types, then long-sized types.  The
        unsigned group goes before the int group because "unsigned int"
        also contains "int".  Anything unrecognized is printed as a
        hexadecimal value.
        """
        hex_value = ("unsigned long", '0x%"PRIx64"')

        uint_args = ( "unsigned int",
                      "qid_t",
                      "uid_t",
                      "gid_t",
                      "clockid_t",
                      "timer_t",
                      "unsigned long",
                      "size_t",
                      "unsigned",
                      "umode_t",
                      "old_uid_t",
                      "old_gid_t",
                      "u64"
                    )

        int_args = (
                      "int",
                      "pid_t",
                      "key_t",
                      "mqd_t",
                      "key_serial_t"
                    )

        lint_args = (
                      "loff_t",
                      "long"
                    )

        if "*" in arg_type:
            return hex_value

        groups = (
            (uint_args, ("unsigned long", "%lu")),
            (int_args,  ("int", "%i")),
            (lint_args, ("long int", "%li")),
        )
        for type_names, printf_pair in groups:
            if any(type_name in arg_type for type_name in type_names):
                return printf_pair

        return hex_value

class kernel_data:
    """Holds data about the kernel upon which the script runs."""

    def __init__(self, kernel_info, get_kernel_source):
        """The __init__ method is called to create the object.

        Inputs:
            kernel_info       -- sequence of three strings: the kernel name,
                                 the version in the form "<name>_<release>",
                                 and the machine architecture
            get_kernel_source -- boolean; when True the syscall headers are
                                 fetched while the object is constructed

        No Output
        """
        self.__name = self.__set_name(kernel_info)
        self.__version = kernel_info[1]
        self.__version_short = self.__set_short_version(kernel_info)
        self.__arch = self.__set_arch(kernel_info)
        self.__syscall_header_src = []

        if get_kernel_source:
            self.__syscall_header_src = self.__fetch_syscall_header_src()

    def __repr__(self):
        """Return string representation of the kernel_data"""
        return self.get_kernel_version()

    def get_kernel_name(self):
        """Return the kernel name"""
        return self.__name

    def get_kernel_version(self):
        """Return the long version of the kernel"""
        return self.__version

    def get_kernel_version_short(self):
        """Return the short kernel version number"""
        return self.__version_short

    def get_kernel_arch(self):
        """Return the kernel architecture"""
        return self.__arch

    def get_header_src(self):
        """Returns the list of lines (bytes) from the source header files.

        The list is empty when the headers were not requested or could not
        be obtained.
        """
        return self.__syscall_header_src

    def __set_name(self, kernel_info):
        """Set the name of the kernel based on the version"""
        return kernel_info[0]

    def __set_short_version(self, kernel_info):
        """Set the short kernel version number based on the version.

        Takes the text after the first "_" of the long version and keeps
        its first two dot-separated fields (e.g. "5.4").
        """
        return ".".join(kernel_info[1].split("_")[1].split(".")[:2])

    def __set_arch(self, kernel_info):
        """Set the kernel architecture based on the version"""
        return kernel_info[2]

    def __fetch_syscall_header_src(self):
        """Attempts to get the two header files that contain the syscall
        prototypes online first then locally.

        Output:
            list of header lines as bytes; empty if neither source worked
        """
        header_lines = self.__download_headers()
        if header_lines is None:
            print("failed to open source-code URLs; attempting to open local files")
            header_lines = self.__read_local_headers()

        if header_lines is None:
            print("failed to open local files")
            # An empty list keeps get_header_src() usable by callers that
            # take len() of the result or iterate over it.
            return []

        return header_lines

    def __download_headers(self):
        """Download both headers; return their lines, or None on a URL error."""
        # Imported here because only this method needs the exception class.
        from urllib.error import URLError

        # NOTE(review): the arch component is whatever get_kernel_arch()
        # returns (e.g. "x86_64") -- confirm it matches the directory name
        # used under arch/ in the kernel tree.
        linux_h_url = "http://lxr.free-electrons.com/source/include/linux/syscalls.h?v={0};raw=1".format(self.get_kernel_version_short())
        arch_h_url = "http://lxr.free-electrons.com/source/arch/{0}/include/asm/syscalls.h?v={1};raw=1".format(self.get_kernel_arch(), self.get_kernel_version_short())

        header_lines = []
        try:
            for url in (linux_h_url, arch_h_url):
                with urllib.request.urlopen(url) as response:
                    header_lines.extend(response.readlines())
        except URLError:
            return None

        return header_lines

    def __read_local_headers(self):
        """Read linux_syscalls.h and arch_syscalls.h from the current
        directory; return their lines, or None if either cannot be read.

        The files are read in binary mode so the lines are bytes, the same
        type the downloaded headers produce.
        """
        header_lines = []
        try:
            for path in ("linux_syscalls.h", "arch_syscalls.h"):
                with open(path, "rb") as header:
                    header_lines.extend(header.readlines())
        except IOError:
            return None

        return header_lines

def get_kernel_info(get_kernel_source):
    """Query uname for the kernel name, release and machine type, then
    wrap the results in a kernel_data object.

    Input:
        get_kernel_source -- boolean, passed through to kernel_data

    Output:
        kernel_data object
    """
    def uname(flag):
        # Run "uname <flag>" and return its output without the newline.
        return subprocess.check_output(["uname", flag]).decode().strip()

    name = uname("-s")
    version = name + "_" + uname("-r")     # e.g. "<name>_<release>"
    machine = uname("-m")
    return kernel_data([name, version, machine], get_kernel_source)


def get_syscalls():
    """Build the syscall table by asking ausyscall for the name of each id.

    Identifiers are probed from 0 upward.  The scan ends at the first
    identifier (after 0) for which ausyscall exits with an error.

    Output:
        (syscall_map, syscall_order) -- a dict from "sys_<name>" to its
        syscall_data object, and the list of those names in identifier order
    """
    def lookup(number):
        # Ask ausyscall for the name that belongs to a syscall number.
        command = ["ausyscall", "{0}".format(number)]
        return subprocess.check_output(command, stderr=subprocess.STDOUT).decode()

    by_name = {}
    in_order = []
    current_id = 0
    raw_name = lookup(current_id)       # A failure for id 0 is not caught.

    while True:
        full_name = "sys_" + raw_name.strip()
        by_name[full_name] = syscall_data(current_id, full_name)
        in_order.append(full_name)
        current_id += 1

        try:
            raw_name = lookup(current_id)
            if raw_name[0] == "_":      # Drop one leading underscore.
                raw_name = raw_name[1:]
        except subprocess.CalledProcessError:
            return (by_name, in_order)

def extract_syscall_prototypes(source_data):
    """Extract the system call prototypes from the source header file data.

    A prototype starts on a line whose first space-separated word is
    "asmlinkage".  Following lines are joined onto it (separated by single
    spaces) until the prototype's terminating ";" has been seen, or until a
    line containing "#", "/*" or "asmlinkage" is reached, or the input ends.

    Input:
        source_data -- list of header lines, each either bytes or str;
                       None is treated as an empty list

    Output:
        list of strings, one single-line prototype each, without leading
        or trailing whitespace
    """
    stop_markers = ("#", "/*", "asmlinkage")

    # Decode once up front; accept lines that are already text.
    lines = [
        raw.decode() if isinstance(raw, bytes) else raw
        for raw in (source_data or [])
    ]

    syscall_prototypes = []
    line = 0
    end = len(lines)

    while line < end:
        if lines[line].split(" ")[0] != "asmlinkage":
            line += 1
            continue

        pieces = [lines[line].strip()]
        line += 1

        # Gather continuation lines.  The bounds check keeps a prototype at
        # the very end of the input from running off the list.
        while (line < end
               and ";" not in pieces[-1]
               and not any(marker in lines[line] for marker in stop_markers)):
            pieces.append(lines[line].strip())
            line += 1

        # Blank continuation lines contribute nothing, so drop them rather
        # than leaving stray spaces in the prototype.
        syscall_prototypes.append(" ".join(piece for piece in pieces if piece))

    return syscall_prototypes

# Words that can only end a type; when one of them is the last word of an
# argument, the prototype did not name that argument.
_UNNAMED_ARG_TAILS = ("int", "long", "unsigned", "char", "short")

# Prototype names that differ from the names the syscall map uses.
_PROTOTYPE_NAME_FIXUPS = {
    "sys_pwrite64": "sys_pwrite",
    "sys_pread64": "sys_pread",
}


def _split_prototype_arg(arg):
    """Split one prototype argument into an (arg_type, arg_name) pair.

    The name is "" when the prototype gives only a type.
    """
    tokens = arg.split()
    if not tokens:
        return ("", "")

    arg_name = tokens[-1]                   # Last word is normally the name.
    arg_type = " ".join(tokens[:-1])        # Everything before it is the type.

    stars = len(arg_name) - len(arg_name.lstrip("*"))
    if stars:
        # C style puts the * next to the name; move it over to the type.
        arg_type = (arg_type + " " + "*" * stars).strip()
        arg_name = arg_name[stars:]
    elif len(tokens) == 1 or arg_name in _UNNAMED_ARG_TAILS:
        # The whole text is a type, e.g. "pid_t" or "unsigned long".
        arg_type = " ".join(tokens)
        arg_name = ""

    return (arg_type, arg_name)


def parse_syscall_protos(syscall_prototypes, syscall_map):
    """Parse the syscall prototypes to add arguments to their
    respective syscall_data objects.

    Inputs:
        syscall_prototypes -- list of single-line C prototypes (strings)
        syscall_map        -- dict from syscall name to an object offering
                              get_args_count() and add_argument(type, name)

    Output:
        the same syscall_map, with arguments added to the entries for
        which a prototype was found

    Prototypes that contain no sys_* function, that name a syscall missing
    from the map, or that take no arguments are skipped.  When a syscall
    already has arguments, later prototypes for it are reported and skipped.
    """
    for prototype in syscall_prototypes:
        # group(1) is the function name, group(2) everything after its "(".
        found = re.search(r"\b(sys_\w+)\s*\((.*)", prototype)
        if found is None:
            continue

        syscall_name = _PROTOTYPE_NAME_FIXUPS.get(found.group(1), found.group(1))

        # Only want to add information to syscalls that we have in our map
        # which was generated with ausyscall.
        if syscall_name not in syscall_map:
            continue

        # Keep only the argument list: cut at the terminating ";" and drop
        # the closing parenthesis, tolerating surrounding whitespace.
        args_text = found.group(2).split(";", 1)[0].strip()
        if args_text.endswith(")"):
            args_text = args_text[:-1]
        curr_args = [arg.strip() for arg in args_text.split(",")]

        if curr_args in ([""], ["void"]):   # No arguments to record.
            continue

        entry = syscall_map[syscall_name]
        if entry.get_args_count() != 0:
            # Likely uses #ifdef to provide alternate definitions of syscall.
            print("skip repeat", syscall_name, "def.: ", entry)
            continue

        for arg in curr_args:
            arg_type, arg_name = _split_prototype_arg(arg)
            entry.add_argument(arg_type, arg_name)

    return syscall_map

def create_struct_c(syscall_map, syscall_order, kernel_info):
    """Generate the C source and write it to generated-linux.c.

    The file holds, in order: a header comment with the kernel version and
    the generation time, the fixed preamble (includes and the
    get_process_name helper), one print function per syscall, the shared
    gt_linux_print_sysret function, and the GT_LINUX_SYSCALLS table with
    one row per syscall followed by a NULL terminator row.

    Inputs:
        syscall_map   -- dict from syscall name to the object handed to
                         create_print_function
        syscall_order -- list of syscall names; determines the order of the
                         print functions and of the table rows
        kernel_info   -- object providing get_kernel_version()

    No Output (the file is written in the current working directory)
    """
    generated_at = time.strftime("%d %b %Y %H:%M:%S", time.localtime())

    # Collect the pieces and join once at the end.
    chunks = ["/* Generated on {0} on {1}*/\n".format(kernel_info.get_kernel_version(), generated_at)]

    # generate beginning of source file
    chunks.append("""
#include <libvmi/libvmi.h>
#include <libvmi/events.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>
#include <inttypes.h>

#include "generated-linux.h"

static const int RETURN_ADDR_WIDTH = sizeof(void *);

static char *
get_process_name(vmi_instance_t vmi, vmi_pid_t pid) 
{
	/* Gets the process name of the process with the input pid */
	/* offsets from the LibVMI config file */	
	unsigned long task_offset = vmi_get_offset(vmi, "linux_tasks");
	unsigned long pid_offset = vmi_get_offset(vmi, "linux_pid");
	unsigned long name_offset = vmi_get_offset(vmi, "linux_name");
	
	/* addresses for the linux process list and current process */
	addr_t list_head = 0;
	addr_t list_curr = 0;
	addr_t curr_proc = 0;
	
	vmi_pid_t curr_pid = 0;		/* pid of the processes task struct we are examining */
	char *proc = NULL;		/* process name of the current process we are examining */

	list_head = vmi_translate_ksym2v(vmi, "init_task") + task_offset; 	/* get the address to the head of the process list */

	if (list_head == task_offset) {
		fprintf(stderr, "failed to read address for init_task\\n");
		goto done;
	}
	
	list_curr = list_head;							/* set the current process to the head */

	do{
		curr_proc = list_curr - task_offset;						/* subtract the task offset to get to the start of the task_struct */
		if (VMI_FAILURE == vmi_read_32_va(vmi, curr_proc + pid_offset, 0, (uint32_t*)&curr_pid)) {		/* read the current pid using the pid offset from the start of the task struct */
			fprintf(stderr, "failed to get the pid of the process we are examining\\n");
			goto done;
		}
	
		if (pid == curr_pid) {
			proc = vmi_read_str_va(vmi, curr_proc + name_offset, 0);		/* get the process name if the current pid is equal to the pis we are looking for */
			goto done;								/* go to done to exit */
		}
	
		if (VMI_FAILURE == vmi_read_addr_va(vmi, list_curr, 0, &list_curr)) {				/* read the memory from the address of list_curr which will return a pointer to the */
			fprintf(stderr, "failed to get the next task in the process list\\n");
			goto done;
		}

	} while (list_curr != list_head);							/* next task_struct. Continue the loop until we get back to the beginning as the  */
												/* process list is doubly linked and circular */

done:
	if (NULL == proc) {		/* if proc is NULL we don't know the process name */
		return "unknown";
	}
	
	return proc;

}

""")

    # generate the print function for each syscall
    chunks.extend(create_print_function(syscall_map[name]) for name in syscall_order)

    # generate the print_sysret_info function
    chunks.append("""void gt_linux_print_sysret(vmi_instance_t vmi, vmi_event_t *event, vmi_pid_t pid, gt_tid_t tid, void *user_data) {
	reg_t syscall_return = event->x86_regs->rax;
	reg_t rsp = event->x86_regs->rsp;
	fprintf(stderr, "pid: %u/0x%"PRIx64" (%s) return: 0x%"PRIx64"\\n", pid, rsp - RETURN_ADDR_WIDTH, get_process_name(vmi, pid), syscall_return);
}

""")

    # Generate SYSCALLS table.
    chunks.append("""const GTSyscallCallback GT_LINUX_SYSCALLS[] = {
""")

    # generate the actual struct list
    chunks.extend(
        "\t{{ \"{0}\", gt_linux_print_syscall_{1}, gt_linux_print_sysret }},\n".format(name, name)
        for name in syscall_order
    )

    chunks.append("""\t{ NULL, NULL, NULL }
};
""")

    # write all the functions to generated-linux.c; the with-block makes
    # sure the file is flushed and closed.
    with open("generated-linux.c", "w") as output:
        output.write("".join(chunks))

def create_print_function(syscall_data):
    """Return the C source of the print function for one syscall.

    The function is named gt_linux_print_syscall_<syscall name>; its body
    looks up the process name, runs the statements produced by
    create_printf_statement and returns NULL.
    """
    name = syscall_data.get_syscall_name()

    opening = "\n".join((
        "void *gt_linux_print_syscall_{0}(vmi_instance_t vmi, vmi_event_t *event, vmi_pid_t pid, gt_tid_t tid, void *user_data)",
        "{{",
        "\tchar *proc = get_process_name(vmi, pid);",
        "",
    )).format(name)

    statements = create_printf_statement(syscall_data.get_args_count(), syscall_data)

    closing = "\treturn NULL;\n}\n\n"

    return opening + statements + closing

def create_printf_statement(count, syscall_data):
    """Return the C statements that load and print a syscall's arguments.

    Inputs:
        count        -- number of arguments to print
        syscall_data -- object providing get_syscall_name(),
                        get_arg_printf_type(i) and get_arg_format(i)

    Output:
        string with one variable declaration per argument (argument i is
        read from the i-th register of rdi, rsi, rdx, r10, r8, r9), a
        declaration of rsp, and a single fprintf call
    """
    registers = ("rdi", "rsi", "rdx", "r10", "r8", "r9")
    printf_types = [syscall_data.get_arg_printf_type(index) for index in range(count)]

    # Declarations: string arguments are read from guest memory, all other
    # arguments are taken straight from the register.
    declarations = []
    for index, printf_type in enumerate(printf_types):
        if printf_type == "char *":
            template = "\tchar *arg{0} = vmi_read_str_va(vmi, event->x86_regs->{1}, pid);\n"
        else:
            template = "\treg_t arg{0} = event->x86_regs->{1};\n"
        declarations.append(template.format(index, registers[index]))
    declarations.append('\treg_t rsp = event->x86_regs->rsp;\n')

    # Format placeholders: strings are wrapped in escaped double quotes.
    placeholders = []
    for index, printf_type in enumerate(printf_types):
        if printf_type == "char *":
            placeholders.append('\\"{0}\\"'.format(syscall_data.get_arg_format(index)))
        else:
            placeholders.append('{0}'.format(syscall_data.get_arg_format(index)))

    # Trailing fprintf arguments, each cast to its printf type.
    casts = "".join(
        ', ({0}) arg{1}'.format(printf_type, index)
        for index, printf_type in enumerate(printf_types)
    )

    return (
        "".join(declarations)
        + '\tfprintf(stderr, "pid: %u/0x%"PRIx64" (%s) syscall: %s('
        + ", ".join(placeholders)
        + ')\\n", pid, rsp, proc, "' + syscall_data.get_syscall_name() + '"'
        + casts
        + ');\n'
    )

def main(get_kernel_source):
    """Run the whole generation pipeline.

    Gathers information on the running kernel, builds the mapping of
    syscall names to syscall numbers, optionally extracts and parses the
    syscall prototypes from the Linux header sources to fill in argument
    information, and finally generates generated-linux.c for use with
    guestrace.

    Input:
        get_kernel_source -- boolean; when False the header sources are
                             neither extracted nor parsed

    No Output
    """
    print("Getting kernel info!")
    kernel = get_kernel_info(get_kernel_source)

    print("Getting syscall info! takes a few seconds...")
    syscalls, order = get_syscalls()

    if not get_kernel_source:
        print("Skipped functions invloving the linux source headers!")
    else:
        print("Gathering syscall prototypes!")
        prototypes = extract_syscall_prototypes(kernel.get_header_src())

        print("Parsing prototypes!")
        # The arguments are added to the objects already held in the map.
        parse_syscall_protos(prototypes, syscalls)

    print("Generating generated-linux.c!")
    create_struct_c(syscalls, order, kernel)

if __name__ == "__main__":
    # Command-line entry point: -n / --no_source turns the header lookup off.
    option_parser = OptionParser()
    option_parser.add_option(
        "-n",
        "--no_source",
        dest="get_source",
        default=True,
        action="store_false",
        help="Don't look for linux source files",
    )
    parsed_options, _positional = option_parser.parse_args()
    main(parsed_options.get_source)
