Skip to content

middleware

stac_fastapi.api.middleware

Api middleware.

CORSMiddleware

Bases: CORSMiddleware

Subclass of Starlette's standard CORS middleware with default values set to those recommended by the STAC API spec.

https://github.com/radiantearth/stac-api-spec/blob/914cf8108302e2ec734340080a45aaae4859bb63/implementation.md#cors

Source code in stac_fastapi/api/middleware.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class CORSMiddleware(_CORSMiddleware):
    """Subclass of Starlette's standard CORS middleware with default values set to those
    recommended by the STAC API spec.

    https://github.com/radiantearth/stac-api-spec/blob/914cf8108302e2ec734340080a45aaae4859bb63/implementation.md#cors
    """

    def __init__(
        self,
        app: ASGIApp,
        allow_origins: typing.Sequence[str] = ("*",),
        allow_methods: typing.Sequence[str] = (
            "OPTIONS",
            "POST",
            "GET",
        ),
        allow_headers: typing.Sequence[str] = ("Content-Type",),
        allow_credentials: bool = False,
        allow_origin_regex: typing.Optional[str] = None,
        expose_headers: typing.Sequence[str] = (),
        max_age: int = 600,
    ) -> None:
        """Create CORS middleware.

        Every argument is handed to the parent constructor unchanged; this
        subclass only supplies different defaults.

        Args:
            app: The ASGI application to wrap.
            allow_origins: Forwarded to the parent's ``allow_origins``.
            allow_methods: Forwarded to the parent's ``allow_methods``.
            allow_headers: Forwarded to the parent's ``allow_headers``.
            allow_credentials: Forwarded to the parent's ``allow_credentials``.
            allow_origin_regex: Forwarded to the parent's ``allow_origin_regex``.
            expose_headers: Forwarded to the parent's ``expose_headers``.
            max_age: Forwarded to the parent's ``max_age``.
        """
        # Forward by keyword: positional forwarding would silently bind values to
        # the wrong parameters if the parent's parameter order ever differs.
        super().__init__(
            app,
            allow_origins=allow_origins,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
            allow_credentials=allow_credentials,
            allow_origin_regex=allow_origin_regex,
            expose_headers=expose_headers,
            max_age=max_age,
        )

__init__

__init__(
    app: ASGIApp,
    allow_origins: Sequence[str] = ("*",),
    allow_methods: Sequence[str] = ("OPTIONS", "POST", "GET"),
    allow_headers: Sequence[str] = ("Content-Type",),
    allow_credentials: bool = False,
    allow_origin_regex: Optional[str] = None,
    expose_headers: Sequence[str] = (),
    max_age: int = 600,
) -> None

Create CORS middleware.

Source code in stac_fastapi/api/middleware.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def __init__(
    self,
    app: ASGIApp,
    allow_origins: typing.Sequence[str] = ("*",),
    allow_methods: typing.Sequence[str] = (
        "OPTIONS",
        "POST",
        "GET",
    ),
    allow_headers: typing.Sequence[str] = ("Content-Type",),
    allow_credentials: bool = False,
    allow_origin_regex: typing.Optional[str] = None,
    expose_headers: typing.Sequence[str] = (),
    max_age: int = 600,
) -> None:
    """Create CORS middleware.

    Every argument is handed to the parent constructor unchanged; only the
    defaults are chosen here.

    Args:
        app: The ASGI application to wrap.
        allow_origins: Forwarded to the parent's ``allow_origins``.
        allow_methods: Forwarded to the parent's ``allow_methods``.
        allow_headers: Forwarded to the parent's ``allow_headers``.
        allow_credentials: Forwarded to the parent's ``allow_credentials``.
        allow_origin_regex: Forwarded to the parent's ``allow_origin_regex``.
        expose_headers: Forwarded to the parent's ``expose_headers``.
        max_age: Forwarded to the parent's ``max_age``.
    """
    # Forward by keyword: positional forwarding would silently bind values to
    # the wrong parameters if the parent's parameter order ever differs.
    super().__init__(
        app,
        allow_origins=allow_origins,
        allow_methods=allow_methods,
        allow_headers=allow_headers,
        allow_credentials=allow_credentials,
        allow_origin_regex=allow_origin_regex,
        expose_headers=expose_headers,
        max_age=max_age,
    )

ProxyHeaderMiddleware

Account for forwarding headers when deriving base URL.

Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing. Default to what can be derived from the URL if no headers provided. Middleware updates the host header that is interpreted by starlette when deriving Request.base_url.

Source code in stac_fastapi/api/middleware.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
class ProxyHeaderMiddleware:
    """Account for forwarding headers when deriving base URL.

    Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing.
    Default to what can be derived from the URL if no headers provided. Middleware updates
    the host header that is interpreted by starlette when deriving Request.base_url.
    """

    def __init__(self, app: ASGIApp):
        """Create proxy header middleware.

        Args:
            app: The wrapped ASGI application; it receives every scope after
                this middleware has (possibly) adjusted it.
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Rewrite scheme and host header of HTTP scopes, then call the wrapped app.

        Scopes whose type is not ``"http"`` are passed through unchanged.
        """
        if scope["type"] == "http":
            proto, domain, port = self._get_forwarded_url_parts(scope)
            scope["scheme"] = proto
            if domain is not None:
                port_suffix = ""
                if port is not None:
                    # Spell the port out only when it differs from the constant
                    # that belongs to the effective scheme.
                    if (proto == "http" and port != HTTP_PORT) or (
                        proto == "https" and port != HTTPS_PORT
                    ):
                        port_suffix = f":{port}"

                scope["headers"] = self._replace_header_value_by_name(
                    scope,
                    "host",
                    f"{domain}{port_suffix}",
                )

        await self.app(scope, receive, send)

    @staticmethod
    def _split_host_port(host_header: str) -> Tuple[str, Optional[int]]:
        """Split a Host header value into ``(host, port)``.

        The port is returned as an ``int`` so that it compares correctly with
        integer port constants. A trailing ``:port`` is recognised only after a
        plain name or a bracketed IPv6 literal; anything else is returned
        untouched with a port of ``None``.
        """
        host, sep, port = host_header.rpartition(":")
        if (
            sep
            and port.isascii()
            and port.isdigit()
            and (":" not in host or host.endswith("]"))
        ):
            return host, int(port)
        return host_header, None

    def _get_forwarded_url_parts(
        self, scope: Scope
    ) -> Tuple[str, Optional[str], Optional[int]]:
        """Work out the externally visible ``(proto, domain, port)`` for a scope.

        Starting values come from the scope's scheme and its Host header (or the
        ``server`` entry when there is no Host header). They are then overridden
        by the ``Forwarded`` header if present, otherwise by ``X-Forwarded-*``.
        """
        proto = scope.get("scheme", "http")
        header_host = self._get_header_value_by_name(scope, "host")
        if header_host is None:
            # "server" may be missing or None; fall back to "unknown" instead
            # of failing on the unpack.
            domain, port = scope.get("server") or (None, None)
        else:
            domain, port = self._split_host_port(header_host)

        port_str = None  # make sure it is defined in all paths since we access it later

        if forwarded := self._get_header_value_by_name(scope, "forwarded"):
            # Later entries of the comma-separated list override earlier ones.
            for proxy in forwarded.split(","):
                if proto_expr := _PROTO_HEADER_REGEX.search(proxy):
                    proto = proto_expr.group("proto")
                if host_expr := _HOST_HEADER_REGEX.search(proxy):
                    domain = host_expr.group("host")
                    port_str = host_expr.group("port")  # None if not present in the match

        else:
            domain = self._get_header_value_by_name(scope, "x-forwarded-host", domain)
            proto = self._get_header_value_by_name(scope, "x-forwarded-proto", proto)
            port_str = self._get_header_value_by_name(scope, "x-forwarded-port", port)

        with contextlib.suppress(ValueError):  # ignore ports that are not valid integers
            port = int(port_str) if port_str is not None else port

        return (proto, domain, port)

    def _get_header_value_by_name(
        self, scope: Scope, header_name: str, default_value: Optional[str] = None
    ) -> Optional[str]:
        """Return the decoded value of ``header_name`` from ``scope["headers"]``.

        ``default_value`` is returned when the header is absent or when it
        occurs more than once.
        """
        # Compare raw bytes so unrelated headers are never decoded.
        wanted = header_name.encode()
        candidates = [
            value.decode() for key, value in scope["headers"] if key == wanted
        ]
        return candidates[0] if len(candidates) == 1 else default_value

    @staticmethod
    def _replace_header_value_by_name(
        scope: Scope, header_name: str, new_value: str
    ) -> List[Tuple[bytes, bytes]]:
        """Return the scope's headers with ``header_name`` set to ``new_value``.

        All existing entries for that header are dropped and a single new entry
        is appended at the end; the scope itself is not modified.
        """
        wanted = header_name.encode()
        headers = [
            (name, value) for name, value in scope["headers"] if name != wanted
        ]
        headers.append((wanted, new_value.encode()))
        return headers

__call__ async

__call__(scope: Scope, receive: Receive, send: Send) -> None

Call from stac-fastapi framework.

Source code in stac_fastapi/api/middleware.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """Adjust scheme and host header of HTTP scopes, then hand over to the app.

    Scopes of any other type are forwarded to the wrapped application as-is.
    """
    if scope["type"] != "http":
        await self.app(scope, receive, send)
        return

    scheme, host, port = self._get_forwarded_url_parts(scope)
    scope["scheme"] = scheme

    if host is not None:
        # The port is written out only when one is known and it differs from
        # the constant associated with the effective scheme.
        show_port = port is not None and (
            (scheme == "http" and port != HTTP_PORT)
            or (scheme == "https" and port != HTTPS_PORT)
        )
        authority = f"{host}:{port}" if show_port else f"{host}"
        scope["headers"] = self._replace_header_value_by_name(
            scope, "host", authority
        )

    await self.app(scope, receive, send)

__init__

__init__(app: ASGIApp)

Create proxy header middleware.

Source code in stac_fastapi/api/middleware.py
60
61
62
def __init__(self, app: ASGIApp):
    """Set up the middleware around an ASGI application.

    Args:
        app: The ASGI application this middleware wraps.
    """
    # Stored on the instance so it can be reached after construction.
    self.app = app