该中间件的功能是：检查客户端 IP，如果不在允许列表（`allows`）中，就不再调用后续服务，直接返回 403 Forbidden 响应。代码如下：


use std::{
    collections::HashSet,
    future::{ready, Ready},
    net::IpAddr,
};

use actix_web::{
    body::EitherBody,
    dev::{forward_ready, Service, ServiceRequest, ServiceResponse, Transform},
    Error, HttpResponse,
};
use futures_util::future::{Either, LocalBoxFuture};


/// Middleware factory that only lets requests from allow-listed client IPs
/// reach the wrapped services; everything else is answered with 403.
///
/// Build it with [`IpChecker::default`] and chain [`IpChecker::allow`]:
///
/// ```ignore
/// App::new().wrap(IpChecker::default().allow("127.0.0.1"))
/// ```
#[derive(Debug, Clone, Default)]
pub struct IpChecker {
    /// IP addresses, in textual form (e.g. `"127.0.0.1"`), whose requests are
    /// let through. Duplicates collapse because this is a set.
    pub allows: HashSet<String>,
}

impl IpChecker {
    /// Adds `ip` to the allow-list and returns the updated checker so calls
    /// can be chained.
    #[must_use = "`allow` returns the updated checker; dropping it discards the entry"]
    pub fn allow(mut self, ip: &str) -> Self {
        self.allows.insert(ip.to_string());
        self
    }
}

impl<S> Transform<S, ServiceRequest> for IpChecker
where
    S: Service<ServiceRequest, Response = ServiceResponse, Error = Error>,
    S::Future: 'static,
{
    type Response = S::Response;
    type Error = S::Error;
    type InitError = ();
    type Transform = IpCheckerMiddleware<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(IpCheckerMiddleware { service, allows: self.allows.clone() }))
    }
}

pub struct IpCheckerMiddleware<S> {
    service: S,
    allows: HashSet<String>
}

impl<S> Service<ServiceRequest> for IpCheckerMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse, Error = Error>,
    S::Future: 'static,
{
    type Response = S::Response;
    type Error = Error;
    type Future = LocalBoxFuture<'static, Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let conn_info = req.connection_info().clone();
        let mut forbidden = true;
        if let Some(val) =  conn_info.realip_remote_addr() {
            println!("Real Address {:?}", val);
            if self.allows.contains(val) {
                forbidden = false
            }
        }
        log::info!("You requested: {}", req.path());

        let either = if forbidden {
            Either::Left(req.into_response(HttpResponse::Forbidden().body("Forbidden")))
        } else {
            Either::Right(self.service.call(req))
        };

        Box::pin(async move {
            let a = match either {
                Either::Left(res) => Ok(res),
                Either::Right(fut) => fut.await,
            };
            a
            // let res = fut.await?;
            // Ok(HttpResponse::Forbidden().finish())
            // Ok(res)
        })
    }
}

在 `main.rs` 中，创建 HTTP 服务器（`HttpServer` / `App`）的代码如下：


/// Builds the Actix application described by `opts` and serves it until the
/// server stops.
///
/// Every worker gets its own CORS policy (any origin, method and header;
/// preflight cache of 3600 seconds) and an `IpChecker` that only allows
/// `127.0.0.1`. Each worker also gets a clone of the shared `AppState`.
pub async fn start_server(opts: &ServerOpts) -> std::io::Result<()> {
    let state = AppState {
        content_dir: opts.content_dir.clone(),
    };
    let listen_on = (opts.host.as_str(), opts.port);

    let server = HttpServer::new(move || {
        let ip_checker = IpChecker::default().allow("127.0.0.1");
        let cors_policy = Cors::default()
            .allow_any_origin()
            .allow_any_method()
            .allow_any_header()
            .max_age(3600);

        // Registered first, so IpChecker is the innermost layer. Actix runs
        // middleware in reverse registration order, so Cors sees each request
        // before IpChecker does.
        App::new()
            .wrap(ip_checker)
            .wrap(cors_policy)
            .app_data(web::Data::new(state.clone()))
            .service(hello)
            .service(upload_file)
    });

    server.bind(listen_on)?.run().await
}

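注意：如果中间件把内层服务的响应类型限定为 `ServiceResponse`（即 `ServiceResponse<BoxBody>`），就必须把 `IpChecker` 作为第一个 `wrap` 的中间件（即最内层），否则会报类型不兼容错误，因为 `Cors` 等中间件会改变响应体类型。如果让中间件对响应体类型 `B` 泛型，并用 `EitherBody` 统一“内层服务的响应”和“403 响应”两种类型，就可以不受注册顺序限制。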