// Copyright (C) The Arvados Authors. All rights reserved.
//
// SPDX-License-Identifier: AGPL-3.0

package ws

import (
	"context"
	"net/http"
	"net/url"
	"time"

	"git.arvados.org/arvados.git/sdk/go/arvados"
	"git.arvados.org/arvados.git/sdk/go/ctxlog"
)

const (
	maxPermCacheAge = time.Hour
	minPermCacheAge = 5 * time.Minute
)

type permChecker interface {
	SetToken(token string)
	Check(ctx context.Context, uuid string) (bool, error)
}

func newPermChecker(ac *arvados.Client) permChecker {
	return &cachingPermChecker{
		ac:         ac,
		token:      "-",
		cache:      make(map[string]cacheEnt),
		maxCurrent: 16,
	}
}

type cacheEnt struct {
	time.Time
	allowed bool
}

type cachingPermChecker struct {
	ac         *arvados.Client
	token      string
	cache      map[string]cacheEnt
	maxCurrent int

	nChecks  uint64
	nMisses  uint64
	nInvalid uint64
}

func (pc *cachingPermChecker) SetToken(token string) {
	if pc.token == token {
		return
	}
	pc.token = token
	pc.cache = make(map[string]cacheEnt)
}

func (pc *cachingPermChecker) Check(ctx context.Context, uuid string) (bool, error) {
	pc.nChecks++
	logger := ctxlog.FromContext(ctx).
		WithField("token", pc.token).
		WithField("uuid", uuid)
	pc.tidy()
	now := time.Now()
	if perm, ok := pc.cache[uuid]; ok && now.Sub(perm.Time) < maxPermCacheAge {
		logger.WithField("allowed", perm.allowed).Debug("cache hit")
		return perm.allowed, nil
	}

	path, err := pc.ac.PathForUUID("get", uuid)
	if err != nil {
		pc.nInvalid++
		return false, err
	}

	pc.nMisses++
	ctx = arvados.ContextWithAuthorization(ctx, "Bearer "+pc.token)
	ctx, cancel := context.WithDeadline(ctx, time.Now().Add(time.Minute))
	defer cancel()
	var buf map[string]interface{}
	err = pc.ac.RequestAndDecodeContext(ctx, &buf, "GET", path, nil, url.Values{
		"include_trash": {"true"},
		"select":        {`["uuid"]`},
	})

	var allowed bool
	if err == nil {
		allowed = true
	} else if txErr, ok := err.(*arvados.TransactionError); ok && pc.isNotAllowed(txErr.StatusCode) {
		allowed = false
	} else {
		// If "context deadline exceeded", "client
		// disconnected", HTTP 5xx, network error, etc., don't
		// cache the result.
		logger.WithError(err).Error("lookup error")
		return false, err
	}
	logger.WithField("allowed", allowed).Debug("cache miss")
	pc.cache[uuid] = cacheEnt{Time: now, allowed: allowed}
	return allowed, nil
}

func (pc *cachingPermChecker) isNotAllowed(status int) bool {
	switch status {
	case http.StatusForbidden, http.StatusUnauthorized, http.StatusNotFound:
		return true
	default:
		return false
	}
}

func (pc *cachingPermChecker) tidy() {
	if len(pc.cache) <= pc.maxCurrent*2 {
		return
	}
	tooOld := time.Now().Add(-minPermCacheAge)
	for uuid, t := range pc.cache {
		if t.Before(tooOld) {
			delete(pc.cache, uuid)
		}
	}
	pc.maxCurrent = len(pc.cache)
}