Source code for geo_sampling.extractor

"""Road extraction functionality."""

import csv
from typing import List, Optional, Union
from pathlib import Path

from ._types import RoadSegment, RoadTypeFilter
from .data import GADMProvider, OSMProvider
from .visualization import RoadPlotter


class RoadExtractor:
    """Main class for extracting road segments from OpenStreetMap data.

    This class handles the complete workflow:

    1. Download administrative boundaries from GADM
    2. Generate BBBike extract URLs
    3. Download OSM data
    4. Extract and segment roads

    Extracted segments are cached on the instance together with the
    parameters they were extracted with. A later request is served from the
    cache only when it asks for the same segment length and for road types
    the cached extraction is known to contain; otherwise the extraction is
    run again.
    """

    # Lower-cased country name -> three-letter code passed to the GADM
    # provider. Simple mapping for common countries; names that are not
    # listed fall back to their first three letters (with a warning).
    _COUNTRY_CODES = {
        "india": "IND",
        "united states": "USA",
        "usa": "USA",
        "thailand": "THA",
        "singapore": "SGP",
        "united kingdom": "GBR",
        "uk": "GBR",
        "canada": "CAN",
        "australia": "AUS",
        "china": "CHN",
        "japan": "JPN",
        "germany": "DEU",
        "france": "FRA",
        "italy": "ITA",
        "spain": "ESP",
        "brazil": "BRA",
        "mexico": "MEX",
    }

    def __init__(
        self, country: str, region: str, admin_level: int = 1, data_dir: str = "data"
    ):
        """Initialize road extractor.

        Args:
            country: Country name (e.g., "India")
            region: Region name (e.g., "NCT of Delhi")
            admin_level: Administrative level (1-4)
            data_dir: Directory for caching downloaded data
        """
        self.country = country
        self.region = region
        self.admin_level = admin_level
        self.data_dir = data_dir

        # Initialize data providers
        self.gadm = GADMProvider(data_dir)
        self.osm = OSMProvider(data_dir)

        # Cached data
        self._road_segments: Optional[List[RoadSegment]] = None
        self._country_code: Optional[str] = None

        # Parameters the cached segments were extracted with. A road-type
        # value of None means the extraction was not restricted by type.
        self._cached_segment_length: Optional[int] = None
        self._cached_road_types: Optional[frozenset] = None

    def get_roads(
        self,
        road_types: RoadTypeFilter = None,
        segment_length: int = 500,
        force_refresh: bool = False,
    ) -> List[RoadSegment]:
        """Extract road segments for the configured region.

        Args:
            road_types: Road types to include (None = all types)
            segment_length: Target segment length in meters
            force_refresh: Ignore any cached segments and extract again

        Returns:
            List of RoadSegment objects (a new list on every call)
        """
        # Reuse the cache only if it was built with a matching segment
        # length and is guaranteed to contain every requested road type.
        if not force_refresh and self._cache_covers(road_types, segment_length):
            return self._filter_by_road_types(self._road_segments, road_types)

        print(f"Extracting roads for {self.region}, {self.country}...")

        # Step 1: Get country code and validate
        country_code = self._get_country_code()
        print(f"Using country code: {country_code}")

        # Step 2: Load administrative boundaries
        print(f"Loading administrative boundaries (level {self.admin_level})...")
        region_names, region_polygon, bbox = self.gadm.load_boundaries(
            country_code, self.admin_level, self.region
        )
        print(f"Found regions: {region_names}")
        print(
            f"Bounding box: {bbox.min_lat:.3f}, {bbox.min_long:.3f} to {bbox.max_lat:.3f}, {bbox.max_long:.3f}"
        )

        # Step 3: Download OSM data
        print("Downloading OSM data from BBBike...")
        roads_shapefile = self.osm.download_osm_data(region_polygon, self.region, bbox)
        print(f"OSM data downloaded to: {roads_shapefile}")

        # Step 4: Extract road segments
        print("Extracting road segments...")
        # Only a list is pushed down to the OSM provider; a single string is
        # applied afterwards by _filter_by_road_types.
        extraction_filter = road_types if isinstance(road_types, list) else None
        self._road_segments = self.osm.extract_road_segments(
            roads_shapefile,
            extraction_filter,
            segment_length,
        )
        # Remember how the cache was produced so later calls can tell
        # whether it is able to answer them.
        self._cached_segment_length = segment_length
        self._cached_road_types = (
            None if extraction_filter is None else frozenset(extraction_filter)
        )
        print(f"Extracted {len(self._road_segments)} road segments")

        # Return filtered results
        return self._filter_by_road_types(self._road_segments, road_types)

    def get_available_road_types(self) -> List[str]:
        """Get list of available road types in the extracted data.

        If nothing has been extracted yet, or the cached extraction was
        restricted to specific road types, an unrestricted extraction is run
        first (with the default segment length) so that no type is missed.

        Returns:
            Sorted list of unique road types
        """
        if self._road_segments is None or self._cached_road_types is not None:
            self.get_roads()  # Extract roads of every type first

        return sorted({segment.osm_type for segment in self._road_segments})

    def save_csv(
        self,
        path: Union[str, Path],
        road_types: RoadTypeFilter = None,
        segment_length: int = 500,
    ) -> None:
        """Save road segments to CSV file.

        Args:
            path: Output file path
            road_types: Road types to include
            segment_length: Target segment length in meters
        """
        segments = self.get_roads(road_types, segment_length)

        columns = [
            "segment_id",
            "osm_id",
            "osm_name",
            "osm_type",
            "start_lat",
            "start_long",
            "end_lat",
            "end_long",
        ]

        # newline="" lets the csv module control line endings itself.
        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()
            for segment in segments:
                writer.writerow(segment.to_dict())

        print(f"Saved {len(segments)} segments to {path}")

    def plot(
        self,
        road_types: RoadTypeFilter = None,
        segment_length: int = 500,
        title: Optional[str] = None,
        figsize: tuple = (12, 10),
    ) -> None:
        """Plot the extracted road segments.

        Args:
            road_types: Road types to include
            segment_length: Target segment length in meters
            title: Plot title (auto-generated if None or empty)
            figsize: Figure size
        """
        segments = self.get_roads(road_types, segment_length)

        if not title:
            title = (
                f"Road Segments: {self.region}, {self.country} (N = {len(segments)})"
            )

        plotter = RoadPlotter(figsize=figsize)
        plotter.plot(segments, title=title)

    def to_dataframe(
        self, road_types: RoadTypeFilter = None, segment_length: int = 500
    ):
        """Convert road segments to pandas DataFrame.

        Args:
            road_types: Road types to include
            segment_length: Target segment length in meters

        Returns:
            pandas.DataFrame with road segment data
        """
        # Imported here so pandas is only needed when this method is used.
        import pandas as pd

        segments = self.get_roads(road_types, segment_length)
        data = [segment.to_dict() for segment in segments]
        return pd.DataFrame(data)

    def _get_country_code(self) -> str:
        """Get three-letter country code for the country.

        The name is matched case-insensitively, ignoring surrounding
        whitespace. Names missing from the mapping fall back to their first
        three letters, upper-cased, and a warning is printed.
        """
        if self._country_code:
            return self._country_code

        name = self.country.strip()
        known_code = self._COUNTRY_CODES.get(name.lower())
        if known_code is not None:
            self._country_code = known_code
            return self._country_code

        # Try first 3 letters as fallback
        self._country_code = name[:3].upper()
        print(
            f"Warning: Using fallback country code '{self._country_code}' for '{self.country}'"
        )
        return self._country_code

    def _cache_covers(self, road_types: RoadTypeFilter, segment_length: int) -> bool:
        """Tell whether the cached segments can answer the given request."""
        if self._road_segments is None:
            return False
        if self._cached_segment_length != segment_length:
            return False
        if self._cached_road_types is None:
            # The cache holds every road type, so any filter can be applied.
            return True
        if not road_types:
            # Every type is wanted, but the cache was restricted to some.
            return False
        wanted = {road_types} if isinstance(road_types, str) else set(road_types)
        return wanted <= self._cached_road_types

    def _filter_by_road_types(
        self, segments: List[RoadSegment], road_types: RoadTypeFilter
    ) -> List[RoadSegment]:
        """Filter segments by road types.

        Always returns a new list, so callers can modify the result without
        touching the list that was passed in (typically the cache).
        """
        if not road_types:
            return list(segments)

        # Accept a single string as well as a collection of strings
        wanted = {road_types} if isinstance(road_types, str) else set(road_types)
        return [seg for seg in segments if seg.osm_type in wanted]
def extract_roads(
    country: str,
    region: str,
    admin_level: int = 1,
    road_types: RoadTypeFilter = None,
    segment_length: int = 500,
    data_dir: str = "data",
) -> List[RoadSegment]:
    """Extract roads in a single call.

    Builds a RoadExtractor from the location arguments and immediately asks
    it for the road segments.

    Args:
        country: Country name
        region: Region name
        admin_level: Administrative level
        road_types: Road types to include
        segment_length: Target segment length in meters
        data_dir: Data directory

    Returns:
        List of RoadSegment objects
    """
    return RoadExtractor(country, region, admin_level, data_dir).get_roads(
        road_types=road_types,
        segment_length=segment_length,
    )