"""
Security middleware and authentication for Odoo API Bridge.
"""
import hashlib
import os
import secrets
import configparser
import socket
from typing import Optional, Dict, List
from fastapi import HTTPException, Header, Request, Depends, status
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
import ipaddress
from datetime import datetime, timedelta
import jwt
# Extracts the bearer token from the Authorization header for verify_jwt_token.
# auto_error=True (HTTPBearer's default) makes FastAPI reject requests whose
# header is missing or not a Bearer credential.
security_scheme = HTTPBearer(auto_error=True)
[docs]
class SecurityConfig:
    """Load and hold the API bridge's security settings.

    Settings come from an INI file, which is created with defaults if missing:

    * ``[security_settings]`` -- ``jwt_secret``, ``rate_limit`` and
      ``allowed_ip_ranges`` (comma-separated CIDR networks, IPs or hostnames).
    * ``[api_keys]`` -- ``<client name> = <sha256 hex>,<description>,<status>``.
      Only entries whose status is ``active`` are loaded.

    Attributes:
        config_file: Path of the INI file.
        config: The parsed ``configparser.ConfigParser``.
        api_keys: SHA-256 hex digest (lower-case) -> client name, active keys only.
        jwt_secret: Secret used to sign and verify JWTs.
        allowed_ip_ranges: Whitelist entries as strings.
        rate_limit: Value of ``rate_limit`` (default 100).
    """

    DEFAULT_CONFIG_FILE = "config/api_security.ini"
    _DEFAULT_IP_RANGES = '10.0.0.0/8,172.16.0.0/12,192.168.0.0/16'
    # Placeholder that older default configs shipped with. It is public, so
    # tokens signed with it can be forged by anyone who has seen this file.
    _PLACEHOLDER_JWT_SECRET = 'your-super-secret-jwt-key-change-in-production'

    def __init__(self, config_file: str = DEFAULT_CONFIG_FILE):
        """Load every setting from *config_file*, creating it if absent."""
        self.config_file = config_file
        self.config = configparser.ConfigParser()
        self._load_config_file()
        self.api_keys = self._load_api_keys()
        # JWT_SECRET in the environment overrides the file.
        self.jwt_secret = self._get_jwt_secret()
        self.allowed_ip_ranges = self._load_ip_ranges()
        self.rate_limit = self._get_rate_limit()

    def _load_config_file(self):
        """Read ``self.config_file``, or create it with defaults if it is missing."""
        if os.path.exists(self.config_file):
            # ConfigParser.read silently skips files it cannot open.
            if not self.config.read(self.config_file):
                print(f"Warning: Could not read config file {self.config_file}. Using defaults.")
        else:
            print(f"Warning: Config file {self.config_file} not found. Using defaults.")
            self._create_default_config()

    def _create_default_config(self):
        """Fill in default settings and write them to ``self.config_file``.

        A fresh random JWT secret is generated rather than writing a shared
        placeholder. The file is created owner-read/write only (0o600, further
        limited by the umask) because it contains that secret.
        """
        self.config['security_settings'] = {
            'jwt_secret': secrets.token_urlsafe(48),
            'rate_limit': '100',
            'allowed_ip_ranges': self._DEFAULT_IP_RANGES,
        }
        self.config['api_keys'] = {}
        self.config['client_permissions'] = {}
        # Create the directory that actually holds config_file (not a fixed 'config').
        config_dir = os.path.dirname(self.config_file)
        if config_dir:
            os.makedirs(config_dir, exist_ok=True)
        fd = os.open(self.config_file, os.O_WRONLY | os.O_CREAT | os.O_TRUNC, 0o600)
        with os.fdopen(fd, 'w') as f:
            self.config.write(f)
        print(f"Created default config at {self.config_file}")

    def _get_jwt_secret(self) -> str:
        """Return the JWT signing secret.

        Precedence: ``JWT_SECRET`` env var, then ``security_settings.jwt_secret``.
        If neither is set, a random per-process secret is used (with a warning)
        instead of a well-known constant. Tokens signed with it will not
        validate after a restart or in another process.
        """
        env_secret = os.getenv("JWT_SECRET")
        if env_secret:
            return env_secret
        secret = self.config.get('security_settings', 'jwt_secret', fallback='')
        if not secret:
            print("Warning: No jwt_secret configured; using a random per-process secret. "
                  "Set JWT_SECRET so tokens stay valid across restarts and workers.")
            return secrets.token_urlsafe(48)
        if secret == self._PLACEHOLDER_JWT_SECRET:
            # Kept for backward compatibility, but loudly flagged.
            print("Warning: jwt_secret is the public placeholder value. "
                  "Set JWT_SECRET or change it in the config file.")
        return secret

    def _load_ip_ranges(self) -> List[str]:
        """Return the whitelist entries, stripped, with blanks removed."""
        ranges_str = self.config.get('security_settings', 'allowed_ip_ranges',
                                     fallback=self._DEFAULT_IP_RANGES)
        return [entry.strip() for entry in ranges_str.split(',') if entry.strip()]

    def _get_rate_limit(self) -> int:
        """Return ``security_settings.rate_limit`` as an int (default 100)."""
        return self.config.getint('security_settings', 'rate_limit', fallback=100)

    def _load_api_keys(self) -> Dict[str, str]:
        """Map each active key's SHA-256 hex digest to its client name.

        Fields are whitespace-stripped and the digest is lower-cased, so it
        matches ``hashlib.sha256(...).hexdigest()`` output. Entries with fewer
        than three comma-separated fields are skipped with a warning.
        """
        api_keys = {}
        if not self.config.has_section('api_keys'):
            return api_keys
        for key_name in self.config.options('api_keys'):
            fields = [field.strip() for field in self.config.get('api_keys', key_name).split(',')]
            if len(fields) < 3:
                print(f"Warning: Ignoring malformed API key entry '{key_name}' "
                      "(expected '<sha256>,<description>,<status>').")
                continue
            hashed_key, key_status = fields[0].lower(), fields[2].lower()
            # Only active keys are accepted; key_status avoids shadowing fastapi.status.
            if hashed_key and key_status == 'active':
                api_keys[hashed_key] = key_name
        return api_keys

    def _hash_key(self, key: str) -> str:
        """Return the SHA-256 hex digest of *key*, as stored in ``[api_keys]``."""
        return hashlib.sha256(key.encode()).hexdigest()

    def generate_api_key(self) -> str:
        """Return a new random URL-safe API key (32 random bytes).

        Store ``_hash_key(key)`` in ``[api_keys]``, never the key itself.
        """
        return secrets.token_urlsafe(32)
# Module-wide settings object, built once when this module is imported.
security_config = SecurityConfig()
# Sliding-window log for rate_limiter: client IP -> timestamps of recent requests.
# It lives in this process only; a shared store (e.g. Redis) is needed once
# several worker processes serve traffic.
request_counts = dict()
[docs]
async def verify_api_key(x_api_key: str = Header(..., description="API Key for authentication")):
    """Resolve the ``X-API-Key`` header to the client name it was issued to.

    The raw key is never compared directly. Its SHA-256 hex digest is looked
    up in ``security_config.api_keys`` (digest -> client name, active keys only).

    Returns:
        The client name registered for the key.

    Raises:
        HTTPException: 401 if the header value is empty; 403 if the digest is
            not among the loaded keys.
    """
    if x_api_key:
        digest = hashlib.sha256(x_api_key.encode()).hexdigest()
        client_name = security_config.api_keys.get(digest)
        if client_name is not None:
            return client_name
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Invalid API key"
        )
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="API key required in X-API-Key header"
    )
[docs]
def _parse_client_address(value):
    """Return *value* as an ``ipaddress`` object, or ``None`` if it is not an IP.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``, as reported by dual-stack
    listeners) are unwrapped to plain IPv4 so they can match IPv4 ranges.
    """
    try:
        address = ipaddress.ip_address(value.strip())
    except ValueError:  # includes ipaddress.AddressValueError
        return None
    mapped = getattr(address, "ipv4_mapped", None)
    return mapped if mapped is not None else address


async def _address_is_allowed(address, allowed_entries) -> bool:
    """Return True if *address* matches any whitelist entry.

    Each entry is first tried as a CIDR network or single IP. Host bits are
    tolerated, so ``10.0.0.1/8`` means ``10.0.0.0/8``. Anything else is treated
    as a hostname and resolved without blocking the event loop. Entries that
    cannot be resolved are skipped.
    """
    import asyncio

    loop = asyncio.get_running_loop()
    for entry in allowed_entries:
        try:
            network = ipaddress.ip_network(entry, strict=False)
        except ValueError:
            network = None
        if network is not None:
            # Membership across IP versions is simply False, never an error.
            if address in network:
                return True
            continue
        try:
            resolved = await loop.getaddrinfo(entry, None)
        except (OSError, UnicodeError):  # socket.gaierror is an OSError
            continue
        for info in resolved:
            if _parse_client_address(info[4][0]) == address:
                return True
    return False


async def _effective_client_ip(request, allowed_entries) -> str:
    """Return the address the whitelist should be checked against.

    The TCP peer (``request.client.host``) cannot be forged, but any client can
    send an ``X-Forwarded-For`` header. The header is therefore honoured only
    when the peer is a trusted hop: loopback, or itself whitelisted. Even then
    only its right-most entry is used, because that is the one the immediately
    preceding proxy appended; entries further left are client-supplied.
    Returns ``""`` when no peer address is known and no header applies.
    """
    peer_ip = request.client.host if request.client else ""
    forwarded = [hop.strip()
                 for hop in request.headers.get("x-forwarded-for", "").split(",")
                 if hop.strip()]
    if not forwarded:
        return peer_ip
    peer_addr = _parse_client_address(peer_ip) if peer_ip else None
    if peer_addr is not None and (
        peer_addr.is_loopback or await _address_is_allowed(peer_addr, allowed_entries)
    ):
        return forwarded[-1]
    return peer_ip


async def verify_ip_whitelist(request: Request):
    """Allow the request only if the client IP is in the configured whitelist.

    Whitelist entries (``security_config.allowed_ip_ranges``) may be CIDR
    networks, single IPs or hostnames. An empty whitelist disables the check.
    See ``_effective_client_ip`` for when ``X-Forwarded-For`` is trusted.

    Returns:
        True when the client is allowed.

    Raises:
        HTTPException: 403 if the client address is missing, unparseable or
            not whitelisted.
    """
    allowed_entries = security_config.allowed_ip_ranges
    if not allowed_entries:
        return True  # No IP restriction

    client_ip = await _effective_client_ip(request, allowed_entries)
    client_addr = _parse_client_address(client_ip) if client_ip else None
    if client_addr is not None and await _address_is_allowed(client_addr, allowed_entries):
        return True
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail=f"Access denied from IP: {client_ip}"
    )
[docs]
# Length of the sliding window enforced by rate_limiter.
_RATE_LIMIT_WINDOW = timedelta(minutes=1)
# When request_counts was last swept for idle clients (mutable holder, no `global`).
_rate_limit_last_sweep = {"at": datetime.min}


def _evict_idle_clients(cutoff_time):
    """Drop clients from ``request_counts`` with no request after *cutoff_time*."""
    idle = [ip for ip, stamps in request_counts.items()
            if all(ts <= cutoff_time for ts in stamps)]
    for ip in idle:
        del request_counts[ip]


async def rate_limiter(request: Request):
    """Allow at most ``security_config.rate_limit`` requests per client per minute.

    Uses an in-memory sliding window keyed by the direct peer address.
    Requests without a reported peer share the ``"unknown"`` bucket.
    Rejected requests are not recorded.

    Raises:
        HTTPException: 429 when the client has already reached the limit
            within the past minute.
    """
    client_ip = request.client.host if request.client else "unknown"
    current_time = datetime.now()
    cutoff_time = current_time - _RATE_LIMIT_WINDOW

    # Without a sweep, every IP ever seen keeps an entry forever (unbounded
    # memory). Sweeping at most once per window keeps the cost amortized.
    if current_time - _rate_limit_last_sweep["at"] >= _RATE_LIMIT_WINDOW:
        _evict_idle_clients(cutoff_time)
        _rate_limit_last_sweep["at"] = current_time

    recent = [ts for ts in request_counts.get(client_ip, []) if ts > cutoff_time]
    request_counts[client_ip] = recent
    if len(recent) >= security_config.rate_limit:
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail="Rate limit exceeded. Try again later."
        )
    recent.append(current_time)
[docs]
async def security_dependencies(
    request: Request,
    client_name: str = Depends(verify_api_key),
    ip_check = Depends(verify_ip_whitelist),
    rate_check = Depends(rate_limiter)
):
    """Run the API-key, IP-whitelist and rate-limit checks and summarize the caller.

    ``ip_check`` and ``rate_check`` exist only so FastAPI runs those
    dependencies; their values are unused. A failing check raises its own
    ``HTTPException`` before this body executes.

    NOTE(review): FastAPI appears to resolve these sub-dependencies in
    parameter order. If so, a request with a bad API key is rejected before the
    IP whitelist and rate limiter run, and bad-key attempts are never counted.
    Consider putting the API-key dependency last.

    Returns:
        dict with ``client_name``, ``client_ip`` (the direct peer address, or
        ``None`` if the server reported none) and ``timestamp`` (local time,
        ISO 8601).
    """
    # Request.client is None when the ASGI scope carries no client address.
    peer = request.client
    return {
        "client_name": client_name,
        "client_ip": peer.host if peer is not None else None,
        "timestamp": datetime.now().isoformat(),
    }
# Optional: JWT Token authentication (for more advanced use cases)
[docs]
class JWTAuth:
    """HS256 JWT helpers keyed by ``security_config.jwt_secret``."""

    @staticmethod
    def create_token(data: dict, expires_delta: timedelta = timedelta(hours=24)):
        """Return a signed JWT carrying *data* plus an ``exp`` claim.

        Args:
            data: Claims to embed. The caller's dict is copied, not modified.
            expires_delta: Token lifetime, measured from the current UTC time.

        Returns:
            The encoded token, as returned by ``jwt.encode``.
        """
        from datetime import timezone  # only this method needs it

        # Timezone-aware "now"; datetime.utcnow() is deprecated since Python 3.12.
        expire = datetime.now(timezone.utc) + expires_delta
        to_encode = data.copy()
        to_encode.update({"exp": expire})
        return jwt.encode(to_encode, security_config.jwt_secret, algorithm="HS256")

    @staticmethod
    def verify_token(token: str):
        """Decode and validate *token*, returning its claims.

        Raises:
            HTTPException: 401 "Token expired" if ``exp`` has passed; 401
                "Invalid token" for any other validation failure (bad
                signature, malformed token, ...).
        """
        try:
            return jwt.decode(token, security_config.jwt_secret, algorithms=["HS256"])
        except jwt.ExpiredSignatureError as exc:
            # Must precede InvalidTokenError, which is its base class.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Token expired"
            ) from exc
        except jwt.InvalidTokenError as exc:
            # PyJWT's base validation error. PyJWT has no `JWTError` (that is
            # python-jose's name), so referencing it would raise AttributeError.
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="Invalid token"
            ) from exc
[docs]
async def verify_jwt_token(credentials: HTTPAuthorizationCredentials = Depends(security_scheme)):
    """Validate the bearer token taken from the ``Authorization`` header.

    ``security_scheme`` extracts the credentials; the token string is then
    checked by ``JWTAuth.verify_token``, which raises HTTP 401 on failure.

    Returns:
        The decoded claims produced by ``JWTAuth.verify_token``.
    """
    bearer_token = credentials.credentials
    claims = JWTAuth.verify_token(bearer_token)
    return claims
# Google Cloud IAP authentication (for enterprise deployments)
[docs]
# Public keys IAP signs its assertions with, and the issuer those assertions carry.
_IAP_CERTS_URL = "https://www.gstatic.com/iap/verify/public_key"
_IAP_ISSUER = "https://cloud.google.com/iap"


async def verify_iap_jwt(request: Request):
    """Authenticate a request forwarded by Google Cloud Identity-Aware Proxy.

    The signed assertion in ``x-goog-iap-jwt-assertion`` is checked as follows:

    * its signature is verified against IAP's published keys (the default
      Google OAuth2 certs cannot verify IAP assertions);
    * its audience must be
      ``/projects/$GOOGLE_CLOUD_PROJECT/global/backendServices/$BACKEND_SERVICE_ID``;
    * its issuer must be ``https://cloud.google.com/iap``.

    Returns:
        dict with ``user_email``, ``user_id`` (the ``sub`` claim) and ``iss``.

    Raises:
        HTTPException: 401 if the header is missing or the token fails
            verification; 500 if google-auth is not installed or the audience
            environment variables are unset.
    """
    iap_jwt = request.headers.get('x-goog-iap-jwt-assertion')
    if not iap_jwt:
        # Raised outside any try block so the client sees this message unchanged.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="IAP JWT missing"
        )

    try:
        from google.auth import exceptions as google_exceptions
        from google.auth.transport import requests as google_requests
        from google.oauth2 import id_token
    except ImportError as exc:
        # A server installation problem, not a client credential problem.
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="IAP authentication is not available on this server"
        ) from exc

    project = os.getenv('GOOGLE_CLOUD_PROJECT')
    backend_service_id = os.getenv('BACKEND_SERVICE_ID')
    if not project or not backend_service_id:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="IAP audience is not configured"
        )
    # NOTE(review): IAP audiences use the numeric project *number*; confirm
    # GOOGLE_CLOUD_PROJECT holds that rather than the project ID.
    audience = f"/projects/{project}/global/backendServices/{backend_service_id}"

    try:
        decoded_jwt = id_token.verify_token(
            iap_jwt,
            google_requests.Request(),
            audience=audience,
            certs_url=_IAP_CERTS_URL,
        )
    except (ValueError, google_exceptions.GoogleAuthError) as e:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail=f"IAP authentication failed: {str(e)}"
        ) from e

    if decoded_jwt.get('iss') != _IAP_ISSUER:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="IAP authentication failed: unexpected issuer"
        )

    return {
        "user_email": decoded_jwt.get('email'),
        "user_id": decoded_jwt.get('sub'),
        "iss": decoded_jwt.get('iss')
    }