Skip to content

middleware

stac_fastapi.api.middleware

Api middleware.

CORSMiddleware

Bases: starlette.middleware.cors.CORSMiddleware

Subclass of Starlette's standard CORS middleware with default values set to those recommended by the STAC API spec.

https://github.com/radiantearth/stac-api-spec/blob/914cf8108302e2ec734340080a45aaae4859bb63/implementation.md#cors

Source code in stac_fastapi/api/stac_fastapi/api/middleware.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class CORSMiddleware(_CORSMiddleware):
    """Starlette's CORS middleware preconfigured with STAC API recommended defaults.

    Every option behaves exactly as in the Starlette parent class; only the
    default values differ: any origin, the ``OPTIONS``/``POST``/``GET`` methods,
    the ``Content-Type`` request header, no credentials and ``max_age=600``.

    https://github.com/radiantearth/stac-api-spec/blob/914cf8108302e2ec734340080a45aaae4859bb63/implementation.md#cors
    """

    def __init__(
        self,
        app: ASGIApp,
        allow_origins: typing.Sequence[str] = ("*",),
        allow_methods: typing.Sequence[str] = ("OPTIONS", "POST", "GET"),
        allow_headers: typing.Sequence[str] = ("Content-Type",),
        allow_credentials: bool = False,
        allow_origin_regex: typing.Optional[str] = None,
        expose_headers: typing.Sequence[str] = (),
        max_age: int = 600,
    ) -> None:
        """Create CORS middleware, handing every setting to the Starlette parent."""
        # Pass by keyword so the mapping to the parent's parameters is explicit
        # and does not depend on the parent's positional order.
        super().__init__(
            app=app,
            allow_origins=allow_origins,
            allow_methods=allow_methods,
            allow_headers=allow_headers,
            allow_credentials=allow_credentials,
            allow_origin_regex=allow_origin_regex,
            expose_headers=expose_headers,
            max_age=max_age,
        )

__init__

__init__(
    app: ASGIApp,
    allow_origins: Sequence[str] = ("*",),
    allow_methods: Sequence[str] = ("OPTIONS", "POST", "GET"),
    allow_headers: Sequence[str] = ("Content-Type",),
    allow_credentials: bool = False,
    allow_origin_regex: Optional[str] = None,
    expose_headers: Sequence[str] = (),
    max_age: int = 600,
) -> None

Create CORS middleware.

Source code in stac_fastapi/api/stac_fastapi/api/middleware.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def __init__(
    self,
    app: ASGIApp,
    allow_origins: typing.Sequence[str] = ("*",),
    allow_methods: typing.Sequence[str] = ("OPTIONS", "POST", "GET"),
    allow_headers: typing.Sequence[str] = ("Content-Type",),
    allow_credentials: bool = False,
    allow_origin_regex: typing.Optional[str] = None,
    expose_headers: typing.Sequence[str] = (),
    max_age: int = 600,
) -> None:
    """Create CORS middleware, handing every setting to the Starlette parent."""
    # Pass by keyword so the mapping to the parent's parameters is explicit
    # and does not depend on the parent's positional order.
    super().__init__(
        app=app,
        allow_origins=allow_origins,
        allow_methods=allow_methods,
        allow_headers=allow_headers,
        allow_credentials=allow_credentials,
        allow_origin_regex=allow_origin_regex,
        expose_headers=expose_headers,
        max_age=max_age,
    )

ProxyHeaderMiddleware

Account for forwarding headers when deriving base URL.

Prioritise the standard Forwarded header, falling back to the non-standard X-Forwarded-* headers when it is missing. If no forwarding headers are provided, default to what can be derived from the request itself. The middleware rewrites the Host header, which Starlette uses when deriving Request.base_url.

Source code in stac_fastapi/api/stac_fastapi/api/middleware.py
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
class ProxyHeaderMiddleware:
    """Account for forwarding headers when deriving base URL.

    Prioritise standard Forwarded header, look for non-standard X-Forwarded-* if missing.
    Default to what can be derived from the URL if no headers provided. Middleware updates
    the host header that is interpreted by starlette when deriving Request.base_url.
    """

    def __init__(self, app: ASGIApp):
        """Create proxy header middleware.

        Args:
            app: the wrapped ASGI application that receives the rewritten scope.
        """
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        """Call from stac-fastapi framework.

        For ``http`` scopes, overwrite ``scope["scheme"]`` and the ``host`` header
        with the values derived from the forwarding headers, then delegate to the
        wrapped app. Other scope types are passed through untouched.
        """
        if scope["type"] == "http":
            proto, domain, port = self._get_forwarded_url_parts(scope)
            scope["scheme"] = proto
            if domain is not None:
                port_suffix = ""
                if port is not None:
                    # Only spell out the port when it is not the scheme's default.
                    if (proto == "http" and port != HTTP_PORT) or (
                        proto == "https" and port != HTTPS_PORT
                    ):
                        port_suffix = f":{port}"

                scope["headers"] = self._replace_header_value_by_name(
                    scope,
                    "host",
                    f"{domain}{port_suffix}",
                )

        await self.app(scope, receive, send)

    def _get_forwarded_url_parts(
        self, scope: Scope
    ) -> Tuple[str, Optional[str], int]:
        """Work out the scheme, host and port the client originally requested.

        Starts from the connection's own scheme and ``Host`` header (or
        ``scope["server"]`` when there is no ``Host`` header), then overrides them
        with the ``Forwarded`` header if present, otherwise with
        ``X-Forwarded-Host`` / ``X-Forwarded-Proto`` / ``X-Forwarded-Port``.

        Returns:
            ``(proto, domain, port)``. ``domain`` is ``None`` only when no Host
            header, no TCP server address and no forwarded host are available.
        """
        proto = scope.get("scheme", "http")
        # Assume default port based on protocol, can be overridden later
        port = 443 if proto == "https" else 80
        domain: Optional[str] = None

        if header_host := self._get_header_value_by_name(scope, "host"):
            domain, host_port = self._split_host_port(header_host)
            if host_port is not None:
                # ignore ports that are not valid integers
                with contextlib.suppress(ValueError):
                    port = int(host_port)
        else:
            # Not sure when we would not have a host header, but fall back to
            # server info. ASGI allows ``server`` to be None, and unix sockets
            # report ``[path, None]``; neither gives a usable host name.
            server = scope.get("server")
            if server is not None and server[1] is not None:
                server_host, server_port = server
                # Bracket bare IPv6 addresses so "host:port" stays unambiguous.
                domain = f"[{server_host}]" if ":" in server_host else server_host
                port = int(server_port)

        forwarded_port: Optional[str] = None
        forwarding_occurred = any(
            key
            in [
                b"forwarded",
                b"x-forwarded-proto",
                b"x-forwarded-host",
                b"x-forwarded-port",
            ]
            for key, _ in scope["headers"]
        )

        if forwarded := self._get_header_value_by_name(scope, "forwarded"):
            # Each comma-separated element describes one proxy hop; values from
            # later elements overwrite earlier ones.
            # NOTE(review): RFC 7239 lists the client-facing proxy first; confirm
            # that last-wins is the intended precedence.
            for proxy in forwarded.split(","):
                if proto_expr := _PROTO_HEADER_REGEX.search(proxy):
                    proto = proto_expr.group("proto")
                if host_expr := _HOST_HEADER_REGEX.search(proxy):
                    domain = host_expr.group("host")
                    forwarded_port = host_expr.group("port")  # None if not present

        else:
            domain = self._get_header_value_by_name(scope, "x-forwarded-host") or domain
            proto = self._get_header_value_by_name(scope, "x-forwarded-proto") or proto
            forwarded_port = self._get_header_value_by_name(scope, "x-forwarded-port")

        if forwarding_occurred and not forwarded_port:
            # If forwarding occurred but no port was specified, use protocol default
            forwarded_port = "443" if proto == "https" else "80"

        if forwarded_port:
            # ignore ports that are not valid integers
            with contextlib.suppress(ValueError):
                port = int(forwarded_port)

        return (proto, domain, port)

    @staticmethod
    def _split_host_port(host_header: str) -> Tuple[str, Optional[str]]:
        """Split a ``Host`` header value into ``(host, port)``.

        Bracketed IPv6 literals such as ``[::1]:8080`` keep their brackets so the
        host can be written back into a header verbatim. ``port`` is ``None``
        when it is absent or the value cannot be split unambiguously.
        """
        if host_header.startswith("["):
            host, bracket, rest = host_header.partition("]")
            if not bracket:
                # Unterminated IPv6 literal: keep it whole, no port.
                return host_header, None
            port = rest[1:] if rest.startswith(":") else None
            return host + bracket, port or None
        parts = host_header.split(":")
        return parts[0], (parts[1] if len(parts) == 2 else None)

    def _get_header_value_by_name(
        self, scope: Scope, header_name: str, default_value: Optional[str] = None
    ) -> Optional[str]:
        """Return the single value of ``header_name``, else ``default_value``.

        ``default_value`` is returned both when the header is absent and when it
        occurs more than once. ``header_name`` must be lower-case to match the
        header keys in ``scope["headers"]``. Bytes are decoded as latin-1, which
        maps every byte, so malformed non-UTF-8 input cannot raise here.
        """
        name = header_name.encode("latin-1")
        candidates = [
            value.decode("latin-1") for key, value in scope["headers"] if key == name
        ]
        return candidates[0] if len(candidates) == 1 else default_value

    @staticmethod
    def _replace_header_value_by_name(
        scope: Scope, header_name: str, new_value: str
    ) -> List[Tuple[bytes, bytes]]:
        """Return ``scope["headers"]`` with every ``header_name`` entry replaced.

        All existing entries for the header are dropped and one new entry is
        appended at the end. ``new_value`` is latin-1 encoded, mirroring how
        ``_get_header_value_by_name`` decodes, so forwarded values round-trip.
        """
        name = header_name.encode("latin-1")
        return [(key, value) for key, value in scope["headers"] if key != name] + [
            (name, new_value.encode("latin-1"))
        ]

__call__ async

__call__(scope: Scope, receive: Receive, send: Send) -> None

Call from stac-fastapi framework.

Source code in stac_fastapi/api/stac_fastapi/api/middleware.py
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
    """Rewrite the scheme and Host header of HTTP requests, then delegate.

    Non-``http`` scopes reach the wrapped app unchanged. For ``http`` scopes the
    scheme is always replaced; the ``host`` header is replaced only when a
    domain could be derived.
    """
    if scope["type"] == "http":
        scheme, host, port = self._get_forwarded_url_parts(scope)
        scope["scheme"] = scheme
        if host is not None:
            # A port is appended only when it differs from the scheme's default.
            non_default = (scheme == "http" and port != HTTP_PORT) or (
                scheme == "https" and port != HTTPS_PORT
            )
            suffix = f":{port}" if port is not None and non_default else ""
            scope["headers"] = self._replace_header_value_by_name(
                scope, "host", f"{host}{suffix}"
            )

    await self.app(scope, receive, send)

__init__

__init__(app: ASGIApp)

Create proxy header middleware.

Source code in stac_fastapi/api/stac_fastapi/api/middleware.py
60
61
62
def __init__(self, app: ASGIApp) -> None:
    """Wrap ``app`` so its HTTP requests see the proxy-derived scheme and host.

    Args:
        app: the downstream ASGI application to delegate to.
    """
    self.app = app