/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2005  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "DnsCache.h"
#include "RefCountable.h"
#include "AtomicCounter.h"
#include "InetAddr.h"
#include <ace/Synch.h>
#include <ace/Singleton.h>
#include <ace/INET_Addr.h>
#include <ace/Time_Value.h>
#include <boost/multi_index_container.hpp>
#include <boost/multi_index/member.hpp>
#include <boost/multi_index/ordered_index.hpp>
#include <iterator>
#include <utility>

using namespace std;
using namespace boost::multi_index;

/**
 * One cache record: the symbolic address it was stored under, the address
 * it resolved to, and the absolute moment after which it is considered stale.
 */
struct DnsCache::Entry
{
	/**
	 * @param symbolic_addr  lookup key (put() normalizes its port afterwards)
	 * @param resolved_addr  the resolved address to hand back on a hit
	 * @param timeout_       absolute expiry time of this record
	 *
	 * Note: the parameter must not share the member's name resolution with
	 * the member itself; the previous version wrote `timeout(timeout)`, which
	 * self-initialized the member from its own indeterminate value and
	 * silently ignored the argument.
	 */
	Entry(InetAddr const& symbolic_addr,
		ACE_INET_Addr const& resolved_addr,
		ACE_Time_Value const& timeout_)
	:	symbolicAddr(symbolic_addr),
		resolvedAddr(resolved_addr),
		timeout(timeout_)
	{
	}
	
	InetAddr symbolicAddr;
	ACE_INET_Addr resolvedAddr;
	ACE_Time_Value timeout; // absolute time value
};


/**
 * Concrete DnsCache backed by a Boost.MultiIndex container.
 *
 * The container offers two views of the same records:
 *  - SymbolicAddrTag: unique, ordered by symbolic address (used for lookups);
 *  - TimeoutTag: non-unique, ordered by expiry time (used for eviction).
 *
 * put() and get() acquire m_mutex; the private helpers expect the caller
 * to already hold it.
 */
class DnsCache::Impl : public DnsCache
{
public:
	Impl();
	
	virtual ~Impl();
	
	virtual void put(InetAddr const& symbolic_addr, ACE_INET_Addr const& resolved_addr);
	
	virtual bool get(InetAddr const& symbolic_addr, ACE_INET_Addr& resolved_addr);
private:
	// Index selectors for the multi-index container.
	struct SymbolicAddrTag {};
	struct TimeoutTag {};
	
	typedef ACE_Thread_Mutex Mutex;
	
	// Key extractors for each index.
	typedef member<Entry, InetAddr, &Entry::symbolicAddr> SymbolicAddrKey;
	typedef member<Entry, ACE_Time_Value, &Entry::timeout> TimeoutKey;
	
	typedef multi_index_container<
		Entry,
		indexed_by<
			ordered_unique<tag<SymbolicAddrTag>, SymbolicAddrKey>,
			ordered_non_unique<tag<TimeoutTag>, TimeoutKey>
		>
	> Container;
	typedef Container::index<SymbolicAddrTag>::type SymbolicAddrIdx;
	typedef Container::index<TimeoutTag>::type TimeoutIdx;
	
	// Drops records whose expiry time has passed. Caller holds m_mutex.
	void removeTimedOut();
	
	// Evicts soonest-to-expire records beyond CAPACITY. Caller holds m_mutex.
	void removeExcess();
	
	Mutex m_mutex;
	Container m_container;
};


/**
 * Returns the process-wide DnsCache object, managed by ACE_Singleton
 * (guarded by an ACE_Recursive_Thread_Mutex during creation).
 */
DnsCache*
DnsCache::instance()
{
	typedef ACE_Singleton<Impl, ACE_Recursive_Thread_Mutex> Singleton;
	return Singleton::instance();
}


/*============================ DnsCache::Impl ===========================*/

// Members (mutex and container) are default-constructed; no extra setup.
DnsCache::Impl::Impl() {}

// Members release their own resources; no extra teardown.
DnsCache::Impl::~Impl() {}

/**
 * Stores or refreshes the mapping symbolic_addr -> resolved_addr.
 *
 * Stale records are purged first. The new record expires TIMEOUT after
 * the current time of day. Its key is stored with port 80, so every port of
 * the same host shares a single record. If a record for the key already
 * exists it is overwritten in place; otherwise the cache is trimmed back
 * to CAPACITY after the insertion.
 */
void
DnsCache::Impl::put(InetAddr const& symbolic_addr, ACE_INET_Addr const& resolved_addr)
{
	ACE_GUARD_RETURN(Mutex, guard, m_mutex, );
	
	removeTimedOut();
	
	// Absolute expiry: now + TIMEOUT.
	ACE_Time_Value expires_at(TIMEOUT);
	expires_at += ACE_OS::gettimeofday();
	
	Entry record(symbolic_addr, resolved_addr, expires_at);
	// Normalize the port so lookups on other ports hit the same record
	// instead of triggering another resolution.
	record.symbolicAddr.setPort(80);
	
	std::pair<Container::iterator, bool> const outcome = m_container.insert(record);
	if (!outcome.second) {
		// Key already present: overwrite the existing record.
		m_container.replace(outcome.first, record);
		return;
	}
	
	removeExcess();
}

/**
 * Looks up a cached resolution for symbolic_addr.
 *
 * Stale records are purged before the lookup. On a hit, resolved_addr
 * receives the cached address with its port replaced by the caller's port,
 * and true is returned. On a miss, resolved_addr is untouched and false
 * is returned. Also returns false if the mutex cannot be acquired.
 */
bool
DnsCache::Impl::get(InetAddr const& symbolic_addr, ACE_INET_Addr& resolved_addr)
{
	ACE_GUARD_RETURN(Mutex, guard, m_mutex, false);
	
	removeTimedOut();
	
	// Records are keyed with port 80 (see put()); build the key the same way.
	InetAddr lookup_key(symbolic_addr);
	lookup_key.setPort(80);
	
	SymbolicAddrIdx const& by_name = m_container.get<SymbolicAddrTag>();
	SymbolicAddrIdx::const_iterator const hit = by_name.find(lookup_key);
	if (hit == by_name.end()) {
		return false;
	}
	
	// Return the cached address, but on the port the caller asked for.
	resolved_addr = hit->resolvedAddr;
	resolved_addr.set_port_number(symbolic_addr.getPort());
	return true;
}

void
DnsCache::Impl::removeTimedOut()
{
	// we are protected by the caller's guard
	
	ACE_Time_Value now(ACE_OS::gettimeofday());
	TimeoutIdx& idx = m_container.get<TimeoutTag>();
	
	TimeoutIdx::iterator it = idx.upper_bound(now);
	idx.erase(idx.begin(), it);
}

void
DnsCache::Impl::removeExcess()
{
	// we are protected by the caller's guard
	
	int excess = int(m_container.size()) - CAPACITY;
	if (excess > 0) {
		TimeoutIdx& idx = m_container.get<TimeoutTag>();
		TimeoutIdx::iterator end(idx.begin());
		std::advance(end, excess);
		idx.erase(idx.begin(), end);
	}
}
