/* -LICENSE-START-
** Copyright (c) 2013 Blackmagic Design
**
** Permission is hereby granted, free of charge, to any person or organization
** obtaining a copy of the software and accompanying documentation covered by
** this license (the "Software") to use, reproduce, display, distribute,
** execute, and transmit the Software, and to prepare derivative works of the
** Software, and to permit third-parties to whom the Software is furnished to
** do so, all subject to the following:
** 
** The copyright notices in the Software and this entire statement, including
** the above license grant, this restriction and the following disclaimer,
** must be included in all copies of the Software, in whole or in part, and
** all derivative works of the Software, unless such copies or derivative
** works are solely in the form of machine-executable object code generated by
** a source language processor.
** 
** THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
** IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
** FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT
** SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE
** FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE,
** ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
** DEALINGS IN THE SOFTWARE.
** -LICENSE-END-
*/
#include <linux/mm.h>
#include <linux/vmalloc.h>
#include <linux/slab.h>
#include <linux/sched.h>
#include <linux/pci.h>
#include <linux/dma-mapping.h>
#include <linux/scatterlist.h>
#include "bm_mm.h"
#include "bm_util.h"
#include "bm_version.h"

/*
 * A set of page references held by this module.
 * The structure is allocated with room for n_pages trailing page pointers
 * (see the kzalloc() calls that add n_pages * sizeof(struct page*)).
 */
struct bm_user_mem
{
	size_t n_pages; // Number of entries in pages[]
	struct page* pages[0]; // Trailing array of page pointers, sized at allocation time
};

/*
 * Bookkeeping for one dma_map_single() mapping created by
 * bm_dma_bus_map_kernel_subpage().
 */
struct bm_dma_subpage
{
	uint32_t size; // Length in bytes that was passed to dma_map_single()
	dma_addr_t busAddr; // Bus address returned by dma_map_single()
};

/*
 * A kernel mapping (vmap) of a set of held pages.
 * Allocated with room for the trailing page pointer array of umem.
 */
struct bm_mmap
{
	vm_address_t vaddr; // paddr plus the in-page offset of the original address
	vm_address_t paddr; // Base address returned by vmap() (page aligned); passed to vunmap()
	bm_user_mem_t umem; // Must be last member in struct (its pages[] array extends past the end)
};

/*
 * File-wide counters (pages_held, pages_vmapped, pages_mapped, memory_mapped)
 * updated through bm_atomic_add()/bm_atomic_sub() by the functions below and
 * exposed read-only through bm_mm_statistics().
 */
static bm_mm_stats_t statistics = { {0}, {0}, {0}, {0} };

/*
 * Number of pages touched by the byte range [addr, addr + size).
 *
 * An empty range touches no pages. Without the explicit check the result
 * for size == 0 would depend on the alignment of addr (0 when addr is page
 * aligned, 1 otherwise) because of the "size - 1" in the computation below.
 */
static inline size_t get_num_pages(void* addr, size_t size)
{
	vm_address_t first;
	vm_address_t last;

	if (size == 0)
		return 0;

	first = ((vm_address_t)addr) >> PAGE_SHIFT;
	last = ((vm_address_t)addr + size - 1) >> PAGE_SHIFT;
	return last - first + 1;
}

/*
 * Drop the reference held on every non-NULL page in umem->pages.
 *
 * umem:  page set to release; the structure itself is not freed here.
 * dirty: when true, each page is flagged dirty before its reference is dropped.
 *
 * The pages_held statistic is decremented once per released page.
 *
 * NOTE(review): SetPageDirty() only sets the page flag; confirm whether
 * set_page_dirty_lock() is needed for the pages handled here.
 */
static void bm_put_user_pages(bm_user_mem_t* umem, bool dirty)
{
	size_t idx;

	for (idx = 0; idx < umem->n_pages; idx++)
	{
		struct page* page = umem->pages[idx];

		/* Empty slots carry no reference. */
		if (!page)
			continue;

		if (dirty)
			SetPageDirty(page);

		put_page(page);
		bm_atomic_sub(&statistics.pages_held, 1);
	}
}

/*
 * Release every page held by umem (optionally marking them dirty first)
 * and then free the umem structure itself with kfree().
 * umem must not be used after this call.
 */
void bm_mm_put_user_pages(bm_user_mem_t* umem, bool dirty)
{
	bm_put_user_pages(umem, dirty);
	kfree(umem);
}

/*
 * Pin n_pages user pages, starting at the page containing address, into
 * umem->pages.
 *
 * umem:    caller-owned page set with room for n_pages page pointers.
 * address: user address inside the first page to pin.
 * n_pages: number of consecutive pages to pin.
 * write:   request write access to the pages.
 *
 * Returns true when all n_pages pages were pinned; pages_held is then
 * increased by n_pages. Returns false otherwise, in which case no page
 * reference is left held and the statistics are untouched.
 *
 * This function never frees umem: umem may be embedded inside a larger
 * allocation (as it is in bm_mmap_t), so ownership stays with the caller on
 * every path.
 */
static bool bm_get_user_pages(bm_user_mem_t* umem, void* address, size_t n_pages, bool write)
{
	long ret;
	long i;

	umem->n_pages = n_pages;

	if (!current->mm)
		return false;

#if KERNEL_VERSION_OR_LATER(4, 9, 0)
	ret = get_user_pages_unlocked((unsigned long)address & PAGE_MASK, umem->n_pages, umem->pages, write ? FOLL_WRITE : 0);
#elif KERNEL_VERSION_OR_LATER(4, 6, 0)
	ret = get_user_pages_unlocked((unsigned long)address & PAGE_MASK, umem->n_pages, write, 0, umem->pages);
#elif KERNEL_PATCH_VERSION_OR_LATER(4, 4, 168)
	ret = get_user_pages_unlocked(current, current->mm, (unsigned long)address & PAGE_MASK, umem->n_pages, umem->pages, write ? FOLL_WRITE : 0);
#elif KERNEL_VERSION_OR_LATER(4, 0, 0)
	ret = get_user_pages_unlocked(current, current->mm, (unsigned long)address & PAGE_MASK, umem->n_pages, write, 0, umem->pages);
#else
	down_read(&current->mm->mmap_sem);
	ret = get_user_pages(current, current->mm, (unsigned long)address & PAGE_MASK, umem->n_pages, write, 0, umem->pages, NULL);
	up_read(&current->mm->mmap_sem);
#endif

	if (ret < (long)n_pages)
	{
		/*
		 * Error (ret < 0) or short pin: drop only the references that
		 * were actually taken. These pages were never added to
		 * pages_held, so the counter must not be decremented here.
		 */
		for (i = 0; i < ret; i++)
		{
			put_page(umem->pages[i]);
			umem->pages[i] = NULL;
		}
		return false;
	}

	bm_atomic_add(&statistics.pages_held, n_pages);

	return true;
}

/*
 * Allocate a page set covering [address, address + size) and pin the
 * corresponding user pages.
 *
 * Returns the new page set, to be released with bm_mm_put_user_pages(), or
 * NULL when the allocation or the pinning fails. On failure nothing is
 * leaked: the page set allocated here is freed before returning.
 */
bm_user_mem_t* bm_mm_get_user_pages(void* address, vm_size_t size, bool write)
{
	bm_user_mem_t* umem;
	size_t length = get_num_pages(address, size);

	umem = kzalloc(sizeof(bm_user_mem_t) + length * sizeof(struct page*), GFP_KERNEL);
	if (!umem)
		return NULL;

	if (!bm_get_user_pages(umem, address, length, write))
	{
		kfree(umem);
		return NULL;
	}

	return umem;
}

/*
 * Create a contiguous kernel mapping (vmap) of a set of pages.
 *
 * umem:    optional existing page set. When non-NULL, its pages are reused
 *          and an extra reference is taken on each of them; address is then
 *          only used for its in-page offset.
 * address: start address of the buffer. When umem is NULL the pages are
 *          obtained through bm_get_user_pages() from this address.
 * size:    buffer length in bytes (only used when umem is NULL).
 * write:   passed to bm_get_user_pages() when umem is NULL.
 *
 * Returns the new mapping, to be released with bm_mmap_unmap(), or NULL on
 * failure.
 */
bm_mmap_t* bm_mmap_map(bm_user_mem_t* umem, void* address, vm_size_t size, bool write)
{
	bm_mmap_t* mapping;
	size_t page_count;
	size_t idx;

	if (umem)
		page_count = umem->n_pages;
	else
		page_count = get_num_pages(address, size);

	mapping = kzalloc(sizeof(bm_mmap_t) + page_count * sizeof(struct page*), GFP_KERNEL);
	if (!mapping)
		return NULL;

	if (!umem)
	{
		if (!bm_get_user_pages(&mapping->umem, address, page_count, write))
			goto fail_free;
	}
	else
	{
		/* Share the caller's pages, holding our own reference on each. */
		for (idx = 0; idx < page_count; ++idx)
		{
			struct page* page = umem->pages[idx];

			mapping->umem.pages[idx] = page;
			get_page(page);
		}

		mapping->umem.n_pages = page_count;
		bm_atomic_add(&statistics.pages_held, page_count);
	}

	mapping->paddr = (vm_address_t)vmap(mapping->umem.pages, mapping->umem.n_pages, VM_MAP, PAGE_KERNEL);
	if (!mapping->paddr)
		goto fail_put;

	bm_atomic_add(&statistics.pages_vmapped, mapping->umem.n_pages);

	/* vmap() returns a page-aligned base; re-apply the in-page offset. */
	mapping->vaddr = mapping->paddr + offset_in_page(address);

	return mapping;

fail_put:
	bm_put_user_pages(&mapping->umem, 0);
fail_free:
	kfree(mapping);
	return NULL;
}

/*
 * Tear down a mapping created by bm_mmap_map(): remove the kernel mapping,
 * drop the page references held by the mapping and free the structure.
 * mmap must not be used after this call.
 */
void bm_mmap_unmap(bm_mmap_t* mmap)
{
	bm_user_mem_t* held = &mmap->umem;

	bm_atomic_sub(&statistics.pages_vmapped, held->n_pages);
	vunmap((void*)mmap->paddr);

	/* Pages are released without being marked dirty. */
	bm_put_user_pages(held, 0);
	kfree(mmap);
}

/*
 * Return the kernel address stored in the mapping's vaddr field, i.e. the
 * vmap() base plus the in-page offset of the address given to bm_mmap_map().
 */
vm_address_t bm_mmap_get_vaddress(bm_mmap_t* mmap)
{
	return mmap->vaddr;
}

/*
 * Build a scatter-gather table describing all pages of umem
 * (offset 0, total length n_pages * PAGE_SIZE).
 *
 * pci is not used by this function.
 *
 * Returns a newly allocated table, or NULL if either the table structure or
 * its entries could not be allocated.
 */
bm_sg_table_t* bm_dma_sg_from_user_pages(bm_pci_device_t* pci, bm_user_mem_t* umem)
{
	bm_sg_table_t* table;
	int err;

	table = kmalloc(sizeof(*table), GFP_KERNEL);
	if (!table)
		return NULL;

	err = sg_alloc_table_from_pages(table, umem->pages, umem->n_pages, 0, umem->n_pages * PAGE_SIZE, GFP_KERNEL);
	if (err != 0)
	{
		kfree(table);
		return NULL;
	}

	return table;
}

/*
 * Build a scatter-gather table with one PAGE_SIZE entry per page spanned by
 * the buffer [addr, addr + size). Each page is looked up with
 * vmalloc_to_page(), starting from addr rounded down to a page boundary.
 *
 * pci is not used by this function.
 *
 * Returns a newly allocated table, or NULL if either the table structure or
 * its entries could not be allocated.
 */
bm_sg_table_t* bm_dma_sg_from_kernel_vmalloc(bm_pci_device_t* pci, void* addr, vm_size_t size)
{
	bm_sg_table_t* table;
	struct scatterlist* sg;
	vm_address_t page_addr;

	table = kmalloc(sizeof(*table), GFP_KERNEL);
	if (!table)
		return NULL;

	if (sg_alloc_table(table, get_num_pages(addr, size), GFP_KERNEL) != 0)
	{
		kfree(table);
		return NULL;
	}

	/* Walk the entries and the buffer in lock step, one page at a time. */
	page_addr = (vm_address_t)addr & PAGE_MASK;
	for (sg = table->sgl; sg; sg = sg_next(sg))
	{
		sg_set_page(sg, vmalloc_to_page((void*)page_addr), PAGE_SIZE, 0);
		page_addr += PAGE_SIZE;
	}

	return table;
}

/*
 * Map a scatter-gather table for DMA on the given device.
 *
 * On success, stores the number of mapped entries in sgTable->nents, adds
 * orig_nents to the pages_mapped statistic and returns the mapped count.
 *
 * On failure (dma_map_sg() returned <= 0) the table is destroyed AND freed
 * here and 0 is returned; the caller must not touch sgTable afterwards.
 */
int bm_dma_sg_bus_map(bm_pci_device_t* pci, bm_sg_table_t* sgTable, bm_dma_direction_t dir)
{
	int mapped = dma_map_sg(&pci->pdev->dev, sgTable->sgl, sgTable->orig_nents, dir);

	if (likely(mapped > 0))
	{
		sgTable->nents = mapped;
		bm_atomic_add(&statistics.pages_mapped, sgTable->orig_nents);
		return mapped;
	}

	sg_free_table(sgTable);
	kfree(sgTable);
	return 0;
}

/*
 * Undo bm_dma_sg_bus_map(): unmap the table from the device, update the
 * pages_mapped statistic, then destroy and free the table.
 * sgTable must not be used after this call.
 */
void bm_dma_sg_bus_unmap(bm_pci_device_t* pci, bm_sg_table_t* sgTable, bm_dma_direction_t dir)
{
	/* Unmapping uses the same entry count that was given to dma_map_sg(). */
	const unsigned int entries = sgTable->orig_nents;

	dma_unmap_sg(&pci->pdev->dev, sgTable->sgl, entries, dir);
	bm_atomic_sub(&statistics.pages_mapped, entries);

	sg_free_table(sgTable);
	kfree(sgTable);
}

/*
 * Start iterating over the DMA segments of a mapped table.
 *
 * The first segment is reported starting at the in-page offset of
 * sgDataBaseAddr: *segAddress is advanced by that offset and *segSize is
 * reduced by it.
 *
 * Returns the first segment, to be passed to
 * bm_dma_sg_bus_map_next_segment().
 */
bm_sg_segment_t* bm_dma_sg_bus_map_first_segment(bm_sg_table_t* sgTable, vm_address_t sgDataBaseAddr, addr64_t* segAddress, uint64_t* segSize)
{
	struct scatterlist* head = sgTable->sgl;
	unsigned int skip = offset_in_page(sgDataBaseAddr);

	*segSize = sg_dma_len(head) - skip;
	*segAddress = sg_dma_address(head) + skip;

	return head;
}

/*
 * Advance to the segment following currentSeg.
 *
 * Returns the next segment and fills *segAddress / *segSize with its DMA
 * address and length, or returns NULL (outputs untouched) when there is no
 * further segment.
 */
bm_sg_segment_t* bm_dma_sg_bus_map_next_segment(bm_sg_segment_t* currentSeg, addr64_t* segAddress, uint64_t* segSize)
{
	bm_sg_segment_t* following = sg_next(currentSeg);

	if (unlikely(following == NULL))
		return following;

	*segAddress = sg_dma_address(following);
	*segSize = sg_dma_len(following);

	return following;
}

/*
 * Translate a byte offset into a mapped scatter-gather table to a bus address.
 *
 * sgTable:        table previously mapped for DMA (its nents field bounds the walk).
 * sgDataBaseAddr: only its in-page offset is used; the data is taken to start
 *                 that many bytes into the first segment.
 * offset:         byte offset from the start of the data.
 * lenLimit:       optional in/out. On input, the number of bytes the caller
 *                 would like; on output, clamped down to the number of
 *                 bus-contiguous bytes available starting at the returned address.
 *
 * Returns the bus address for the given offset, or 0 when the offset lies
 * beyond the mapped segments or a zero-length segment is encountered while
 * searching. In the 0 case *lenLimit is left unchanged.
 */
addr64_t bm_dma_sg_bus_segment(const bm_sg_table_t* sgTable, vm_address_t sgDataBaseAddr, vm_offset_t offset, uint32_t* lenLimit)
{
	vm_offset_t sumOfSeenSegs = 0;
	unsigned int segLen;
	dma_addr_t segAddr;
	uint32_t remainingContiguousSize;
	addr64_t startAddr;
	int entryIdx = 0;
	struct scatterlist* currentSeg = sgTable->sgl;
	unsigned int basePageOffset = offset_in_page(sgDataBaseAddr);

	// The first segment only contributes the bytes after the in-page offset.
	segLen = sg_dma_len(currentSeg) - basePageOffset;
	// Find the segment that contains 'offset' by accumulating segment lengths.
	for (;;)
	{
		sumOfSeenSegs += segLen;
		if (offset < sumOfSeenSegs)
			break;
		currentSeg = sg_next(currentSeg);
		++entryIdx;
		// Ran past the mapped entries: offset is out of range.
		if (entryIdx >= sgTable->nents || ! currentSeg)
			return 0;
		segLen = sg_dma_len(currentSeg);
		if (! segLen)
			return 0;
	}

	// Bytes from 'offset' to the end of the segment that contains it.
	remainingContiguousSize = sumOfSeenSegs - offset;
	segAddr = sg_dma_address(currentSeg);
	// segLen of the first segment was shortened above, so shift its base to match.
	if (currentSeg == sgTable->sgl)
		segAddr += basePageOffset;
	// End of the segment minus the bytes remaining in it.
	startAddr = segAddr + segLen - remainingContiguousSize;

	if (lenLimit)
	{
		// Extend across following segments for as long as each one starts
		// exactly where the previous one ended on the bus.
		while (remainingContiguousSize < *lenLimit)
		{
			dma_addr_t nextIfContiguous = segAddr + segLen;
			currentSeg = sg_next(currentSeg);
			++entryIdx;
			if (entryIdx >= sgTable->nents || ! currentSeg)
				break;
			segAddr = sg_dma_address(currentSeg);
			if (nextIfContiguous != segAddr)
				break;
			segLen = sg_dma_len(currentSeg);
			if (! segLen)
				break;
			remainingContiguousSize += segLen;
		}
		// Never report more than what is contiguous; never raise the caller's limit.
		if (*lenLimit > remainingContiguousSize)
			*lenLimit = remainingContiguousSize;
	}

	return startAddr;
}

/*
 * Map a kernel buffer for DMA with dma_map_single().
 *
 * Returns a newly allocated descriptor recording the bus address and size,
 * to be released with bm_dma_bus_unmap_kernel_subpage(), or NULL if the
 * descriptor could not be allocated or the mapping failed.
 * On success the memory_mapped statistic grows by size.
 */
bm_dma_subpage_t* bm_dma_bus_map_kernel_subpage(bm_pci_device_t* pci, void* addr, vm_size_t size, bm_dma_direction_t dir)
{
	bm_dma_subpage_t* mapping;
	dma_addr_t bus;

	mapping = kzalloc(sizeof(*mapping), GFP_KERNEL);
	if (mapping == NULL)
		return NULL;

	bus = dma_map_single(&pci->pdev->dev, addr, size, dir);
	mapping->busAddr = bus;
	if (dma_mapping_error(&pci->pdev->dev, bus))
	{
		kfree(mapping);
		return NULL;
	}

	mapping->size = size;
	bm_atomic_add(&statistics.memory_mapped, size);

	return mapping;
}

/*
 * Undo bm_dma_bus_map_kernel_subpage(): unmap the buffer from the device,
 * update the memory_mapped statistic and free the descriptor.
 * subpage must not be used after this call.
 */
void bm_dma_bus_unmap_kernel_subpage(bm_pci_device_t* pci, bm_dma_subpage_t* subpage, bm_dma_direction_t dir)
{
	const uint32_t mapped_bytes = subpage->size;

	dma_unmap_single(&pci->pdev->dev, subpage->busAddr, mapped_bytes, dir);
	bm_atomic_sub(&statistics.memory_mapped, mapped_bytes);

	kfree(subpage);
}

/*
 * Bus address of the byte at the given offset inside a mapped subpage.
 *
 * lenLimit is optional; when given, it is clamped down to the number of
 * bytes between offset and the end of the mapping. addr is not used.
 * No range check is performed on offset.
 */
addr64_t bm_dma_bus_address_of_kernel_subpage(const bm_dma_subpage_t* subpage, void* addr, vm_offset_t offset, uint32_t* lenLimit)
{
	if (lenLimit != NULL)
	{
		if (*lenLimit > subpage->size - offset)
			*lenLimit = subpage->size - offset;
	}

	return subpage->busAddr + offset;
}

/*
 * Thin wrapper: returns the result of phys_to_virt(phys) as a vm_address_t.
 */
vm_address_t bm_mm_phys_to_virt(addr64_t phys)
{
	return (vm_address_t)phys_to_virt(phys);
}

/*
 * Return a read-only pointer to this file's statistics counters.
 * The pointer refers to a static object, so it stays valid and the values
 * it points to change as pages are held, mapped and released.
 */
const bm_mm_stats_t* bm_mm_statistics(void)
{
	return &statistics;
}
