# Source code for geo_sampling.sampler

"""Road sampling functionality."""

import csv
import random
from collections import defaultdict
from typing import List, Optional, Union, Dict
from pathlib import Path

from ._types import RoadSegment, RoadTypeFilter
from .visualization import RoadPlotter


class RoadSampler:
    """Sample road segments using random, stratified, or length-based strategies.

    The sampler keeps a reference to the segment list it was given and never
    mutates it. Every sampling method accepts an optional ``seed``; when one
    is supplied a private ``random.Random`` instance is used, so the
    process-wide ``random`` state is left untouched while the draws stay
    reproducible.
    """

    def __init__(self, segments: List[RoadSegment]):
        """Initialize sampler with road segments.

        Args:
            segments: List of road segments to sample from

        Raises:
            ValueError: If ``segments`` is empty or contains objects that are
                not ``RoadSegment`` instances
        """
        self.segments = segments
        self._validate_segments()

    def random_sample(
        self,
        n: int,
        road_types: RoadTypeFilter = None,
        seed: Optional[int] = None,
    ) -> List[RoadSegment]:
        """Get random sample of road segments (without replacement).

        Args:
            n: Number of segments to sample
            road_types: Road types to include (None = all types)
            seed: Random seed for reproducibility

        Returns:
            List of sampled RoadSegment objects; empty when ``n <= 0``

        Raises:
            ValueError: If sample size exceeds available segments
        """
        rng = self._make_rng(seed)
        available = self._filter_by_road_types(road_types)

        if n > len(available):
            raise ValueError(
                f"Sample size {n} exceeds available segments {len(available)}"
            )
        if n <= 0:
            return []

        return rng.sample(available, n)

    def stratified_sample(
        self,
        n: int,
        by: str = "osm_type",
        road_types: RoadTypeFilter = None,
        seed: Optional[int] = None,
    ) -> List[RoadSegment]:
        """Get stratified sample maintaining proportions across strata.

        Sample sizes per stratum are computed with the largest-remainder
        method, so the result always contains exactly ``n`` segments and no
        stratum is asked for more segments than it holds.

        Args:
            n: Total number of segments to sample
            by: Field to stratify by (default: "osm_type")
            road_types: Road types to include
            seed: Random seed for reproducibility

        Returns:
            List of sampled RoadSegment objects in shuffled order; empty when
            ``n <= 0``

        Raises:
            ValueError: If sample size exceeds available segments
        """
        rng = self._make_rng(seed)
        available = self._filter_by_road_types(road_types)

        if n > len(available):
            raise ValueError(
                f"Sample size {n} exceeds available segments {len(available)}"
            )
        if n <= 0:
            return []

        # Group segments by the stratification attribute.
        strata = defaultdict(list)
        for segment in available:
            strata[getattr(segment, by)].append(segment)

        allocation = self._allocate_proportionally(
            {key: len(members) for key, members in strata.items()}, n
        )

        sampled: List[RoadSegment] = []
        for key, members in strata.items():
            quota = allocation[key]
            if quota > 0:
                sampled.extend(rng.sample(members, quota))

        # Shuffle so the output is not grouped stratum by stratum.
        rng.shuffle(sampled)
        return sampled

    def sample_by_length(
        self,
        target_length_km: float,
        road_types: RoadTypeFilter = None,
        seed: Optional[int] = None,
        segment_length_km: float = 0.5,
    ) -> List[RoadSegment]:
        """Sample segments to approximate a target total length.

        Every segment is assumed to have the same nominal length
        (``segment_length_km``); actual geometry is not measured.

        Args:
            target_length_km: Target total length in kilometers
            road_types: Road types to include
            seed: Random seed for reproducibility
            segment_length_km: Nominal length of one segment in kilometers
                (default 0.5, i.e. 500 m)

        Returns:
            List of sampled RoadSegment objects. Empty when the target is
            shorter than one segment (or negative); a copy of all available
            segments when the target needs more segments than exist.

        Raises:
            ValueError: If ``segment_length_km`` is not positive
        """
        if segment_length_km <= 0:
            raise ValueError("segment_length_km must be positive")

        rng = self._make_rng(seed)
        available = self._filter_by_road_types(road_types)

        # Truncate: never sample more length than was asked for.
        target_n = int(target_length_km / segment_length_km)
        if target_n <= 0:
            return []

        if target_n > len(available):
            print(
                f"Warning: Target length requires {target_n} segments, "
                f"but only {len(available)} available. Using all."
            )
            # `available` is already a fresh list, so the caller cannot
            # mutate the sampler's own segment list through it.
            return available

        return rng.sample(available, target_n)

    def get_road_type_summary(
        self, road_types: RoadTypeFilter = None
    ) -> Dict[str, int]:
        """Get count summary by road type.

        Args:
            road_types: Road types to include

        Returns:
            Dictionary mapping each segment's ``osm_type`` to its count
        """
        counts: Dict[str, int] = {}
        for segment in self._filter_by_road_types(road_types):
            counts[segment.osm_type] = counts.get(segment.osm_type, 0) + 1
        return counts

    def save_csv(self, segments: List[RoadSegment], path: Union[str, Path]) -> None:
        """Save segments to CSV file.

        Each row is produced by the segment's ``to_dict()``; a header row
        with the fixed column list is written first. A confirmation line is
        printed once the file has been written.

        Args:
            segments: Road segments to save
            path: Output file path
        """
        columns = [
            "segment_id",
            "osm_id",
            "osm_name",
            "osm_type",
            "start_lat",
            "start_long",
            "end_lat",
            "end_long",
        ]

        # newline="" lets the csv module control line endings itself.
        with open(path, "w", newline="", encoding="utf-8") as f:
            writer = csv.DictWriter(f, fieldnames=columns)
            writer.writeheader()
            writer.writerows(segment.to_dict() for segment in segments)

        print(f"Saved {len(segments)} segments to {path}")

    def plot_sample(
        self,
        sample: List[RoadSegment],
        title: Optional[str] = None,
        show_comparison: bool = True,
        figsize: tuple = (12, 10),
    ) -> None:
        """Plot sampled segments.

        Args:
            sample: Sampled road segments
            title: Plot title. Only used for the plain (non-comparison) plot;
                a default mentioning the sample size is generated when empty.
            show_comparison: Whether to show comparison with full dataset.
                The comparison is drawn only when the full dataset is larger
                than the sample.
            figsize: Figure size
        """
        plotter = RoadPlotter(figsize=figsize)

        if show_comparison and len(self.segments) > len(sample):
            plotter.plot_sample_comparison(self.segments, sample)
            return

        if not title:
            title = f"Road Segments Sample (N = {len(sample)})"
        plotter.plot(sample, title=title)

    def to_dataframe(self, segments: List[RoadSegment]):
        """Convert segments to pandas DataFrame.

        Args:
            segments: Road segments to convert

        Returns:
            pandas.DataFrame built from each segment's ``to_dict()``
        """
        # Imported lazily so pandas is only needed when this method is used.
        import pandas as pd

        return pd.DataFrame([segment.to_dict() for segment in segments])

    def _filter_by_road_types(self, road_types: RoadTypeFilter) -> List[RoadSegment]:
        """Return a new list of the segments whose ``osm_type`` is selected.

        A falsy filter (``None``, empty) selects everything. A single string
        is treated as a one-element list. The result is always a fresh list,
        never ``self.segments`` itself.
        """
        if not road_types:
            return list(self.segments)

        # Convert single string to list
        if isinstance(road_types, str):
            road_types = [road_types]

        return [seg for seg in self.segments if seg.osm_type in road_types]

    def _validate_segments(self) -> None:
        """Validate input segments.

        Raises:
            ValueError: If no segments were given or any is not a RoadSegment
        """
        if not self.segments:
            raise ValueError("No segments provided for sampling")

        if not all(isinstance(seg, RoadSegment) for seg in self.segments):
            raise ValueError("All segments must be RoadSegment objects")

    @staticmethod
    def _make_rng(seed: Optional[int]):
        """Return the random source to draw from.

        With a seed, a dedicated ``random.Random`` is created so the global
        generator is not reseeded; without one, the shared ``random`` module
        is used.
        """
        return random if seed is None else random.Random(seed)

    @staticmethod
    def _allocate_proportionally(sizes: Dict, n: int) -> Dict:
        """Split ``n`` across strata proportionally to their sizes.

        Uses the largest-remainder method with exact integer arithmetic:
        each stratum first gets ``floor(n * size / total)``, then the units
        still missing go to the strata with the largest fractional parts
        (ties broken by larger stratum, then by first-seen order).

        Expects ``0 < n <= sum(sizes.values())``; under that condition the
        returned quotas sum to ``n`` and none exceeds its stratum size.
        """
        total = sum(sizes.values())
        quotas = {}
        fractions = {}
        for key, size in sizes.items():
            quotas[key], fractions[key] = divmod(n * size, total)

        leftover = n - sum(quotas.values())
        # Sorting is stable (also with reverse=True), so equal candidates
        # keep their insertion order.
        by_priority = sorted(
            sizes, key=lambda key: (fractions[key], sizes[key]), reverse=True
        )
        for key in by_priority[:leftover]:
            quotas[key] += 1
        return quotas
def load_segments_from_csv(path: Union[str, Path]) -> List[RoadSegment]:
    """Load road segments from CSV file.

    The first row is treated as the header; every following row is passed
    as a dict to ``RoadSegment.from_dict``.

    Args:
        path: Path to CSV file

    Returns:
        List of RoadSegment objects, in file order
    """
    # newline="" is required by the csv module so that line breaks embedded
    # in quoted fields are parsed correctly on every platform.
    with open(path, "r", newline="", encoding="utf-8") as f:
        return [RoadSegment.from_dict(row) for row in csv.DictReader(f)]
def sample_roads(
    segments: List[RoadSegment],
    n: int,
    road_types: RoadTypeFilter = None,
    strategy: str = "random",
    seed: Optional[int] = None,
) -> List[RoadSegment]:
    """One-call helper: build a ``RoadSampler`` and draw a sample from it.

    Args:
        segments: Road segments to sample from
        n: Number of segments to sample
        road_types: Road types to include
        strategy: Sampling strategy ("random" or "stratified")
        seed: Random seed

    Returns:
        List of sampled RoadSegment objects

    Raises:
        ValueError: If ``strategy`` is neither "random" nor "stratified"
    """
    # The sampler is built before the strategy is inspected, so any error
    # raised while constructing it surfaces first.
    road_sampler = RoadSampler(segments)

    if strategy == "stratified":
        return road_sampler.stratified_sample(n, road_types=road_types, seed=seed)

    if strategy != "random":
        raise ValueError(f"Unknown sampling strategy: {strategy}")

    return road_sampler.random_sample(n, road_types, seed)